mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
add test for index add and add missing match statements (#1862)
This commit is contained in:
@ -1242,9 +1242,29 @@ impl BackendStorage for MetalStorage {
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::I64, DType::BF16) => "ia_i64_bf16",
|
||||
(DType::I64, DType::F16) => "ia_i64_f16",
|
||||
(DType::I64, DType::F32) => "ia_i64_f32",
|
||||
(DType::I64, DType::I64) => "ia_i64_i64",
|
||||
(DType::I64, DType::U32) => "ia_i64_u32",
|
||||
(DType::I64, DType::U8) => "ia_i64_u8",
|
||||
|
||||
(DType::U32, DType::BF16) => "ia_u32_bf16",
|
||||
(DType::U32, DType::F16) => "ia_u32_f16",
|
||||
(DType::U32, DType::F32) => "ia_u32_f32",
|
||||
(DType::U32, DType::I64) => "ia_u32_i64",
|
||||
(DType::U32, DType::U32) => "ia_u32_u32",
|
||||
(DType::U32, DType::U8) => "ia_u32_u8",
|
||||
|
||||
(DType::U8, DType::BF16) => "ia_u8_bf16",
|
||||
(DType::U8, DType::F16) => "ia_u8_f16",
|
||||
(DType::U8, DType::F32) => "ia_u8_f32",
|
||||
(DType::U8, DType::I64) => "ia_u8_i64",
|
||||
(DType::U8, DType::U32) => "ia_u8_u32",
|
||||
(DType::U8, DType::U8) => "ia_u8_u8",
|
||||
|
||||
_ => Err(MetalError::UnexpectedDType {
|
||||
msg: "index-add ids should be u32",
|
||||
msg: "index-add ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
|
Reference in New Issue
Block a user