Add support for index u8/i64 and input f16/bf16 scatter-add on metal (#1849)

* add support and tests for scatter-add on metal

* add support for all datatypes
This commit is contained in:
Thomas Santerre
2024-03-17 03:09:43 -04:00
committed by GitHub
parent 74bf6994b1
commit db8b24ae92
3 changed files with 123 additions and 2 deletions

View File

@@ -1128,7 +1128,15 @@ impl BackendStorage for MetalStorage {
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "sa_u8_f32",
(DType::U8, DType::F16) => "sa_u8_f16",
(DType::U8, DType::BF16) => "sa_u8_bf16",
(DType::U32, DType::F32) => "sa_u32_f32",
(DType::U32, DType::F16) => "sa_u32_f16",
(DType::U32, DType::BF16) => "sa_u32_bf16",
(DType::I64, DType::F32) => "sa_i64_f32",
(DType::I64, DType::F16) => "sa_i64_f16",
(DType::I64, DType::BF16) => "sa_i64_bf16",
_ => Err(MetalError::UnexpectedDType {
msg: "scatter-add ids should be u8/u32/i64",
expected: DType::U32,