From fa3ea98ba92835960fdd825a5b4dda30ef2baaa4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jan 2024 12:12:56 +0100 Subject: [PATCH] Adding bfloat16 support for the cast kernels. (#1520) --- candle-core/src/metal_backend.rs | 4 ++++ candle-metal-kernels/src/cast.metal | 2 ++ 2 files changed, 6 insertions(+) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index e168c24b..c1c4aa4b 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -596,6 +596,8 @@ impl BackendStorage for MetalStorage { (DType::F32, DType::F16) => "cast_f32_f16", (DType::F16, DType::F32) => "cast_f16_f32", (DType::I64, DType::F32) => "cast_i64_f32", + (DType::F32, DType::BF16) => "cast_f32_bf16", + (DType::BF16, DType::F32) => "cast_bf16_f32", (left, right) => { crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") } @@ -622,6 +624,8 @@ impl BackendStorage for MetalStorage { (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", (DType::I64, DType::F32) => "cast_i64_f32_strided", + (DType::F32, DType::BF16) => "cast_f32_bf16_strided", + (DType::BF16, DType::F32) => "cast_bf16_f32_strided", (left, right) => { crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented") } diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 3baefcc2..e9ab17b1 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -59,4 +59,6 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif #if __METAL_VERSION__ >= 310 +CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) +CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) #endif