mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Support gather on bf16 for metal. (#2035)
This commit is contained in:
@ -1042,6 +1042,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
(DType::U32, DType::F32) => "gather_u32_f32",
|
(DType::U32, DType::F32) => "gather_u32_f32",
|
||||||
(DType::U32, DType::F16) => "gather_u32_f16",
|
(DType::U32, DType::F16) => "gather_u32_f16",
|
||||||
|
(DType::U32, DType::BF16) => "gather_u32_bf16",
|
||||||
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
|
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
|
||||||
};
|
};
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
@ -207,6 +207,9 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
|||||||
|
|
||||||
GATHER_OP(gather_u32_f32, uint, float)
|
GATHER_OP(gather_u32_f32, uint, float)
|
||||||
GATHER_OP(gather_u32_f16, uint, half)
|
GATHER_OP(gather_u32_f16, uint, half)
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
GATHER_OP(gather_u32_bf16, uint, bfloat)
|
||||||
|
#endif
|
||||||
|
|
||||||
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
|
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
|
||||||
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
|
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
|
||||||
|
Reference in New Issue
Block a user