mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Renamed all kernel names.
This commit is contained in:
@ -314,8 +314,8 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float",
|
||||
DType::F16 => "affine_half",
|
||||
DType::F32 => "affine_f32",
|
||||
DType::F16 => "affine_f16",
|
||||
dtype => crate::bail!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine(
|
||||
@ -332,8 +332,8 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float_strided",
|
||||
DType::F16 => "affine_half_strided",
|
||||
DType::F32 => "affine_f32_strided",
|
||||
DType::F16 => "affine_f16_strided",
|
||||
dtype => crate::bail!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine_strided(
|
||||
@ -365,8 +365,8 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_float",
|
||||
DType::F16 => "powf_half",
|
||||
DType::F32 => "powf_f32",
|
||||
DType::F16 => "powf_f16",
|
||||
dtype => crate::bail!("Powf {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_powf(
|
||||
@ -382,8 +382,8 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_float_strided",
|
||||
DType::F16 => "powf_half_strided",
|
||||
DType::F32 => "powf_f32_strided",
|
||||
DType::F16 => "powf_f16_strided",
|
||||
dtype => crate::bail!("Powf {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_powf_strided(
|
||||
@ -414,8 +414,8 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_float",
|
||||
DType::F16 => "elu_half",
|
||||
DType::F32 => "elu_f32",
|
||||
DType::F16 => "elu_f16",
|
||||
dtype => crate::bail!("Powf {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_elu(
|
||||
@ -431,8 +431,8 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_float_strided",
|
||||
DType::F16 => "elu_half_strided",
|
||||
DType::F32 => "elu_f32_strided",
|
||||
DType::F16 => "elu_f16_strided",
|
||||
dtype => crate::bail!("Powf {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_elu_strided(
|
||||
@ -483,11 +483,11 @@ impl BackendStorage for MetalStorage {
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||
_ => crate::bail!("Reduce op for non float"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
|
@ -109,16 +109,16 @@ kernel void FN_NAME##_strided( \
|
||||
} \
|
||||
|
||||
|
||||
AFFINE(affine_float, float)
|
||||
AFFINE(affine_half, half)
|
||||
POWF(powf_float, float)
|
||||
POWF(powf_half, half)
|
||||
ELU(elu_float, float)
|
||||
ELU(elu_half, half)
|
||||
AFFINE(affine_f32, float)
|
||||
AFFINE(affine_f16, half)
|
||||
POWF(powf_f32, float)
|
||||
POWF(powf_f16, half)
|
||||
ELU(elu_f32, float)
|
||||
ELU(elu_f16, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
AFFINE(affine_bfloat, bfloat);
|
||||
POWF(powf_bfloat, bfloat);
|
||||
ELU(elu_bfloat, bfloat);
|
||||
AFFINE(affine_bf16, bfloat);
|
||||
POWF(powf_bf16, bfloat);
|
||||
ELU(elu_bf16, bfloat);
|
||||
#endif
|
||||
|
@ -52,11 +52,11 @@ kernel void FN_NAME_STRIDED( \
|
||||
}
|
||||
|
||||
#define BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
|
||||
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
|
||||
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
|
||||
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
|
||||
|
||||
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
||||
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||
|
||||
|
||||
BINARY_OP(x + y, add)
|
||||
|
@ -125,16 +125,16 @@ macro_rules! ops{
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
|
||||
}
|
||||
)+
|
||||
pub mod copy {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel("copy_float");
|
||||
pub const HALF: Kernel = Kernel("copy_half");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat");
|
||||
pub const FLOAT: Kernel = Kernel("copy_f32");
|
||||
pub const HALF: Kernel = Kernel("copy_f16");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bf16");
|
||||
pub const U32: Kernel = Kernel("copy_u32");
|
||||
pub const U8: Kernel = Kernel("copy_u8");
|
||||
}
|
||||
@ -145,16 +145,16 @@ macro_rules! ops{
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
|
||||
}
|
||||
)+
|
||||
pub mod copy {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel("copy_float_strided");
|
||||
pub const HALF: Kernel = Kernel("copy_half_strided");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
|
||||
pub const FLOAT: Kernel = Kernel("copy_f32_strided");
|
||||
pub const HALF: Kernel = Kernel("copy_f16_strided");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
|
||||
pub const U32: Kernel = Kernel("copy_u32_strided");
|
||||
pub const U8: Kernel = Kernel("copy_u8_strided");
|
||||
}
|
||||
|
@ -71,9 +71,9 @@ kernel void NAME( \
|
||||
} \
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_float, float)
|
||||
REDUCE(x * y, fast_mul_float, float)
|
||||
REDUCE(max(x, y), fast_max_float, float)
|
||||
REDUCE(x + y, fast_sum_f32, float)
|
||||
REDUCE(x * y, fast_mul_f32, float)
|
||||
REDUCE(max(x, y), fast_max_f32, float)
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
@ -142,8 +142,8 @@ kernel void NAME(
|
||||
} \
|
||||
} \
|
||||
|
||||
SOFTMAX(softmax_float, float)
|
||||
SOFTMAX(softmax_half, half)
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
SOFTMAX(softmax_bfloat, bfloat)
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
#endif
|
||||
|
@ -87,11 +87,11 @@ kernel void FN_NAME_STRIDED( \
|
||||
}
|
||||
|
||||
#define UNARY_OP(NAME) \
|
||||
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
|
||||
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
|
||||
UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \
|
||||
UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
|
||||
|
||||
#define BFLOAT_UNARY_OP(NAME) \
|
||||
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
||||
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||
|
||||
|
||||
UNARY_OP(cos)
|
||||
@ -108,8 +108,8 @@ UNARY_OP(round)
|
||||
UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY(id, float, copy_float, copy_float_strided)
|
||||
UNARY(id, half, copy_half, copy_half_strided)
|
||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
|
||||
@ -129,5 +129,5 @@ BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
BFLOAT_UNARY_OP(tanh)
|
||||
|
||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
#endif
|
||||
|
@ -213,9 +213,9 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
let command_buffer = device.command_buffer();
|
||||
let kernels = device.kernels();
|
||||
let name = match storage.dtype() {
|
||||
DType::F32 => "softmax_float",
|
||||
DType::F16 => "softmax_half",
|
||||
DType::BF16 => "softmax_bfloat",
|
||||
DType::F32 => "softmax_f32",
|
||||
DType::F16 => "softmax_f16",
|
||||
DType::BF16 => "softmax_bf16",
|
||||
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user