From 8696cf64947a7f3b712297426078dcf6ab0d199e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 1 Aug 2024 09:03:11 +0100 Subject: [PATCH] Enable the affine kernel for u8/u32. (#2376) --- candle-core/src/metal_backend/mod.rs | 2 ++ candle-metal-kernels/src/affine.metal | 2 ++ 2 files changed, 4 insertions(+) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index fa83692d..58be5502 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -119,6 +119,8 @@ impl BackendStorage for MetalStorage { DType::F32 => "affine_f32", DType::F16 => "affine_f16", DType::BF16 => "affine_bf16", + DType::U8 => "affine_u8", + DType::U32 => "affine_u32", dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine( diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 76c0365a..cbbb03e2 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -109,6 +109,8 @@ kernel void FN_NAME##_strided( \ } \ +AFFINE(affine_u8, uint8_t) +AFFINE(affine_u32, uint32_t) AFFINE(affine_f32, float) AFFINE(affine_f16, half) POWF(powf_f32, float)