From e8ee253ee0766c33ac69f08bb0bcd6601f47ca6f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 18 Dec 2023 11:01:18 +0100 Subject: [PATCH] Missing cast. --- candle-core/src/metal_backend.rs | 2 ++ candle-metal-kernels/src/cast.metal | 1 + 2 files changed, 3 insertions(+) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 21a8967b..0af11a3d 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -578,6 +578,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U8, DType::U32) => "cast_u8_u32", + (DType::U8, DType::F32) => "cast_u8_f32", (DType::F32, DType::F16) => "cast_f32_f16", (DType::F16, DType::F32) => "cast_f16_f32", (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), @@ -598,6 +599,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", + (DType::U8, DType::F32) => "cast_u8_f32_strided", (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", (left, right) => crate::bail!("to dtype {left:?} - {right:?}"), diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 4398e9d4..8481389d 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -48,6 +48,7 @@ kernel void FN_NAME_STRIDED( \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f32_f16, cast_f32_f16_strided, float, half)