diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 50149a9d..158eb8e0 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1042,6 +1042,7 @@ impl BackendStorage for MetalStorage { let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "gather_u32_f32", (DType::U32, DType::F16) => "gather_u32_f16", + (DType::U32, DType::BF16) => "gather_u32_bf16", (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), }; let command_buffer = self.device.command_buffer()?; diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 762b42be..9eee97ca 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -207,6 +207,9 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat) GATHER_OP(gather_u32_f32, uint, float) GATHER_OP(gather_u32_f16, uint, half) +#if defined(__HAVE_BFLOAT__) +GATHER_OP(gather_u32_bf16, uint, bfloat) +#endif SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)