mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Adding bfloat16 support for the cast kernels. (#1520)
This commit is contained in:
@ -596,6 +596,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||||
|
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||||
|
(DType::BF16, DType::F32) => "cast_bf16_f32",
|
||||||
(left, right) => {
|
(left, right) => {
|
||||||
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||||
}
|
}
|
||||||
@ -622,6 +624,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||||
(DType::I64, DType::F32) => "cast_i64_f32_strided",
|
(DType::I64, DType::F32) => "cast_i64_f32_strided",
|
||||||
|
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
|
||||||
|
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
|
||||||
(left, right) => {
|
(left, right) => {
|
||||||
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
|
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
|
||||||
}
|
}
|
||||||
|
@ -59,4 +59,6 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
|
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
|
||||||
|
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
Reference in New Issue
Block a user