diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index d8518b3e..b4a490cd 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -314,8 +314,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "affine_float", - DType::F16 => "affine_half", + DType::F32 => "affine_f32", + DType::F16 => "affine_f16", dtype => crate::bail!("Affine {dtype:?}"), }; candle_metal_kernels::call_affine( @@ -332,8 +332,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "affine_float_strided", - DType::F16 => "affine_half_strided", + DType::F32 => "affine_f32_strided", + DType::F16 => "affine_f16_strided", dtype => crate::bail!("Affine {dtype:?}"), }; candle_metal_kernels::call_affine_strided( @@ -365,8 +365,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "powf_float", - DType::F16 => "powf_half", + DType::F32 => "powf_f32", + DType::F16 => "powf_f16", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_powf( @@ -382,8 +382,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "powf_float_strided", - DType::F16 => "powf_half_strided", + DType::F32 => "powf_f32_strided", + DType::F16 => "powf_f16_strided", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_powf_strided( @@ -414,8 +414,8 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer(); if layout.is_contiguous() && layout.start_offset() == 0 { let name = match self.dtype { - DType::F32 => "elu_float", - DType::F16 => "elu_half", + DType::F32 => "elu_f32", + DType::F16 => "elu_f16", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_elu( @@ -431,8 +431,8 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let name = match self.dtype { - DType::F32 => "elu_float_strided", - DType::F16 => "elu_half_strided", + DType::F32 => "elu_f32_strided", + DType::F16 => "elu_f16_strided", dtype => crate::bail!("Powf {dtype:?}"), }; candle_metal_kernels::call_elu_strided( @@ -483,11 +483,11 @@ impl BackendStorage for MetalStorage { // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. let (name, check_empty, return_index) = match (op, self.dtype) { - (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), - (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), - (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), - (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), - (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), + (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true), _ => crate::bail!("Reduce op for non float"), }; if check_empty && layout.shape().elem_count() == 0 { diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 18adb457..4166d811 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -109,16 +109,16 @@ kernel void FN_NAME##_strided( \ } \ -AFFINE(affine_float, float) -AFFINE(affine_half, half) -POWF(powf_float, float) -POWF(powf_half, half) -ELU(elu_float, float) -ELU(elu_half, half) +AFFINE(affine_f32, float) +AFFINE(affine_f16, half) +POWF(powf_f32, float) +POWF(powf_f16, half) +ELU(elu_f32, float) +ELU(elu_f16, half) #if __METAL_VERSION__ >= 310 -AFFINE(affine_bfloat, bfloat); -POWF(powf_bfloat, bfloat); -ELU(elu_bfloat, bfloat); +AFFINE(affine_bf16, bfloat); +POWF(powf_bf16, bfloat); +ELU(elu_bf16, bfloat); #endif diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index f18cdbb0..ea21bb34 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -52,11 +52,11 @@ kernel void FN_NAME_STRIDED( \ } #define BINARY_OP(FN, NAME) \ -BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \ -BINARY(FN, half, half, NAME##_half, NAME##_half_strided); +BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_BINARY_OP(FN, NAME) \ -BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); BINARY_OP(x + y, add) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 514cf33e..a23aa47c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -125,16 +125,16 @@ macro_rules! ops{ $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); } )+ pub mod copy { use super::Kernel; - pub const FLOAT: Kernel = Kernel("copy_float"); - pub const HALF: Kernel = Kernel("copy_half"); - pub const BFLOAT: Kernel = Kernel("copy_bfloat"); + pub const FLOAT: Kernel = Kernel("copy_f32"); + pub const HALF: Kernel = Kernel("copy_f16"); + pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -145,16 +145,16 @@ macro_rules! ops{ $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); } )+ pub mod copy { use super::Kernel; - pub const FLOAT: Kernel = Kernel("copy_float_strided"); - pub const HALF: Kernel = Kernel("copy_half_strided"); - pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided"); + pub const FLOAT: Kernel = Kernel("copy_f32_strided"); + pub const HALF: Kernel = Kernel("copy_f16_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 3633fdcf..62443660 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -71,9 +71,9 @@ kernel void NAME( \ } \ -REDUCE(x + y, fast_sum_float, float) -REDUCE(x * y, fast_mul_float, float) -REDUCE(max(x, y), fast_max_float, float) +REDUCE(x + y, fast_sum_f32, float) +REDUCE(x * y, fast_mul_f32, float) +REDUCE(max(x, y), fast_max_f32, float) #define SOFTMAX(NAME, T) \ kernel void NAME( \ @@ -142,8 +142,8 @@ kernel void NAME( } \ } \ -SOFTMAX(softmax_float, float) -SOFTMAX(softmax_half, half) +SOFTMAX(softmax_f32, float) +SOFTMAX(softmax_f16, half) #if __METAL_VERSION__ >= 310 -SOFTMAX(softmax_bfloat, bfloat) +SOFTMAX(softmax_bf16, bfloat) #endif diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 765b14a5..553bc506 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -87,11 +87,11 @@ kernel void FN_NAME_STRIDED( \ } #define UNARY_OP(NAME) \ -UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ -UNARY(NAME, half, NAME##_half, NAME##_half_strided); +UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \ +UNARY(NAME, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_UNARY_OP(NAME) \ -UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); +UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); UNARY_OP(cos) @@ -108,8 +108,8 @@ UNARY_OP(round) UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) -UNARY(id, float, copy_float, copy_float_strided) -UNARY(id, half, copy_half, copy_half_strided) +UNARY(id, float, copy_f32, copy_f32_strided) +UNARY(id, half, copy_f16, copy_f16_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint32_t, copy_u32, copy_u32_strided) @@ -129,5 +129,5 @@ BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) -UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) +UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index f00d8e2f..ca23f90e 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -213,9 +213,9 @@ impl candle::CustomOp1 for SoftmaxLastDim { let command_buffer = device.command_buffer(); let kernels = device.kernels(); let name = match storage.dtype() { - DType::F32 => "softmax_float", - DType::F16 => "softmax_half", - DType::BF16 => "softmax_bfloat", + DType::F32 => "softmax_f32", + DType::F16 => "softmax_f16", + DType::BF16 => "softmax_bf16", dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"), };