Renamed all kernel names.

This commit is contained in:
Nicolas Patry
2023-12-15 11:24:47 +01:00
parent 34d83377f6
commit 26540641c1
7 changed files with 56 additions and 56 deletions

View File

@ -314,8 +314,8 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer(); let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 { if layout.is_contiguous() && layout.start_offset() == 0 {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "affine_float", DType::F32 => "affine_f32",
DType::F16 => "affine_half", DType::F16 => "affine_f16",
dtype => crate::bail!("Affine {dtype:?}"), dtype => crate::bail!("Affine {dtype:?}"),
}; };
candle_metal_kernels::call_affine( candle_metal_kernels::call_affine(
@ -332,8 +332,8 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} else { } else {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "affine_float_strided", DType::F32 => "affine_f32_strided",
DType::F16 => "affine_half_strided", DType::F16 => "affine_f16_strided",
dtype => crate::bail!("Affine {dtype:?}"), dtype => crate::bail!("Affine {dtype:?}"),
}; };
candle_metal_kernels::call_affine_strided( candle_metal_kernels::call_affine_strided(
@ -365,8 +365,8 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer(); let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 { if layout.is_contiguous() && layout.start_offset() == 0 {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "powf_float", DType::F32 => "powf_f32",
DType::F16 => "powf_half", DType::F16 => "powf_f16",
dtype => crate::bail!("Powf {dtype:?}"), dtype => crate::bail!("Powf {dtype:?}"),
}; };
candle_metal_kernels::call_powf( candle_metal_kernels::call_powf(
@ -382,8 +382,8 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} else { } else {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "powf_float_strided", DType::F32 => "powf_f32_strided",
DType::F16 => "powf_half_strided", DType::F16 => "powf_f16_strided",
dtype => crate::bail!("Powf {dtype:?}"), dtype => crate::bail!("Powf {dtype:?}"),
}; };
candle_metal_kernels::call_powf_strided( candle_metal_kernels::call_powf_strided(
@ -414,8 +414,8 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer(); let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 { if layout.is_contiguous() && layout.start_offset() == 0 {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "elu_float", DType::F32 => "elu_f32",
DType::F16 => "elu_half", DType::F16 => "elu_f16",
dtype => crate::bail!("Powf {dtype:?}"), dtype => crate::bail!("Powf {dtype:?}"),
}; };
candle_metal_kernels::call_elu( candle_metal_kernels::call_elu(
@ -431,8 +431,8 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} else { } else {
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "elu_float_strided", DType::F32 => "elu_f32_strided",
DType::F16 => "elu_half_strided", DType::F16 => "elu_f16_strided",
dtype => crate::bail!("Powf {dtype:?}"), dtype => crate::bail!("Powf {dtype:?}"),
}; };
candle_metal_kernels::call_elu_strided( candle_metal_kernels::call_elu_strided(
@ -483,11 +483,11 @@ impl BackendStorage for MetalStorage {
// The reduction loop requires the shared array to be properly initialized and for // The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two. // this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) { let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
_ => crate::bail!("Reduce op for non float"), _ => crate::bail!("Reduce op for non float"),
}; };
if check_empty && layout.shape().elem_count() == 0 { if check_empty && layout.shape().elem_count() == 0 {

View File

@ -109,16 +109,16 @@ kernel void FN_NAME##_strided( \
} \ } \
AFFINE(affine_float, float) AFFINE(affine_f32, float)
AFFINE(affine_half, half) AFFINE(affine_f16, half)
POWF(powf_float, float) POWF(powf_f32, float)
POWF(powf_half, half) POWF(powf_f16, half)
ELU(elu_float, float) ELU(elu_f32, float)
ELU(elu_half, half) ELU(elu_f16, half)
#if __METAL_VERSION__ >= 310 #if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat); AFFINE(affine_bf16, bfloat);
POWF(powf_bfloat, bfloat); POWF(powf_bf16, bfloat);
ELU(elu_bfloat, bfloat); ELU(elu_bf16, bfloat);
#endif #endif

View File

@ -52,11 +52,11 @@ kernel void FN_NAME_STRIDED( \
} }
#define BINARY_OP(FN, NAME) \ #define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided); BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \ #define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
BINARY_OP(x + y, add) BINARY_OP(x + y, add)

View File

@ -125,16 +125,16 @@ macro_rules! ops{
$( $(
pub mod $name { pub mod $name {
use super::Kernel; use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
} }
)+ )+
pub mod copy { pub mod copy {
use super::Kernel; use super::Kernel;
pub const FLOAT: Kernel = Kernel("copy_float"); pub const FLOAT: Kernel = Kernel("copy_f32");
pub const HALF: Kernel = Kernel("copy_half"); pub const HALF: Kernel = Kernel("copy_f16");
pub const BFLOAT: Kernel = Kernel("copy_bfloat"); pub const BFLOAT: Kernel = Kernel("copy_bf16");
pub const U32: Kernel = Kernel("copy_u32"); pub const U32: Kernel = Kernel("copy_u32");
pub const U8: Kernel = Kernel("copy_u8"); pub const U8: Kernel = Kernel("copy_u8");
} }
@ -145,16 +145,16 @@ macro_rules! ops{
$( $(
pub mod $name { pub mod $name {
use super::Kernel; use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
} }
)+ )+
pub mod copy { pub mod copy {
use super::Kernel; use super::Kernel;
pub const FLOAT: Kernel = Kernel("copy_float_strided"); pub const FLOAT: Kernel = Kernel("copy_f32_strided");
pub const HALF: Kernel = Kernel("copy_half_strided"); pub const HALF: Kernel = Kernel("copy_f16_strided");
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided"); pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U32: Kernel = Kernel("copy_u32_strided");
pub const U8: Kernel = Kernel("copy_u8_strided"); pub const U8: Kernel = Kernel("copy_u8_strided");
} }

View File

@ -71,9 +71,9 @@ kernel void NAME( \
} \ } \
REDUCE(x + y, fast_sum_float, float) REDUCE(x + y, fast_sum_f32, float)
REDUCE(x * y, fast_mul_float, float) REDUCE(x * y, fast_mul_f32, float)
REDUCE(max(x, y), fast_max_float, float) REDUCE(max(x, y), fast_max_f32, float)
#define SOFTMAX(NAME, T) \ #define SOFTMAX(NAME, T) \
kernel void NAME( \ kernel void NAME( \
@ -142,8 +142,8 @@ kernel void NAME(
} \ } \
} \ } \
SOFTMAX(softmax_float, float) SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_half, half) SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 310 #if __METAL_VERSION__ >= 310
SOFTMAX(softmax_bfloat, bfloat) SOFTMAX(softmax_bf16, bfloat)
#endif #endif

View File

@ -87,11 +87,11 @@ kernel void FN_NAME_STRIDED( \
} }
#define UNARY_OP(NAME) \ #define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided); UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_UNARY_OP(NAME) \ #define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
UNARY_OP(cos) UNARY_OP(cos)
@ -108,8 +108,8 @@ UNARY_OP(round)
UNARY_OP(gelu_erf) UNARY_OP(gelu_erf)
UNARY_OP(erf) UNARY_OP(erf)
UNARY_OP(tanh) UNARY_OP(tanh)
UNARY(id, float, copy_float, copy_float_strided) UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_half, copy_half_strided) UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided) UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided) UNARY(id, uint32_t, copy_u32, copy_u32_strided)
@ -129,5 +129,5 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(tanh)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif #endif

View File

@ -213,9 +213,9 @@ impl candle::CustomOp1 for SoftmaxLastDim {
let command_buffer = device.command_buffer(); let command_buffer = device.command_buffer();
let kernels = device.kernels(); let kernels = device.kernels();
let name = match storage.dtype() { let name = match storage.dtype() {
DType::F32 => "softmax_float", DType::F32 => "softmax_f32",
DType::F16 => "softmax_half", DType::F16 => "softmax_f16",
DType::BF16 => "softmax_bfloat", DType::BF16 => "softmax_bf16",
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"), dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
}; };