mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Renamed all kernel names.
This commit is contained in:
@ -314,8 +314,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "affine_float",
|
DType::F32 => "affine_f32",
|
||||||
DType::F16 => "affine_half",
|
DType::F16 => "affine_f16",
|
||||||
dtype => crate::bail!("Affine {dtype:?}"),
|
dtype => crate::bail!("Affine {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_affine(
|
candle_metal_kernels::call_affine(
|
||||||
@ -332,8 +332,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "affine_float_strided",
|
DType::F32 => "affine_f32_strided",
|
||||||
DType::F16 => "affine_half_strided",
|
DType::F16 => "affine_f16_strided",
|
||||||
dtype => crate::bail!("Affine {dtype:?}"),
|
dtype => crate::bail!("Affine {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_affine_strided(
|
candle_metal_kernels::call_affine_strided(
|
||||||
@ -365,8 +365,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "powf_float",
|
DType::F32 => "powf_f32",
|
||||||
DType::F16 => "powf_half",
|
DType::F16 => "powf_f16",
|
||||||
dtype => crate::bail!("Powf {dtype:?}"),
|
dtype => crate::bail!("Powf {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_powf(
|
candle_metal_kernels::call_powf(
|
||||||
@ -382,8 +382,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "powf_float_strided",
|
DType::F32 => "powf_f32_strided",
|
||||||
DType::F16 => "powf_half_strided",
|
DType::F16 => "powf_f16_strided",
|
||||||
dtype => crate::bail!("Powf {dtype:?}"),
|
dtype => crate::bail!("Powf {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_powf_strided(
|
candle_metal_kernels::call_powf_strided(
|
||||||
@ -414,8 +414,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "elu_float",
|
DType::F32 => "elu_f32",
|
||||||
DType::F16 => "elu_half",
|
DType::F16 => "elu_f16",
|
||||||
dtype => crate::bail!("Powf {dtype:?}"),
|
dtype => crate::bail!("Powf {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_elu(
|
candle_metal_kernels::call_elu(
|
||||||
@ -431,8 +431,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "elu_float_strided",
|
DType::F32 => "elu_f32_strided",
|
||||||
DType::F16 => "elu_half_strided",
|
DType::F16 => "elu_f16_strided",
|
||||||
dtype => crate::bail!("Powf {dtype:?}"),
|
dtype => crate::bail!("Powf {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_elu_strided(
|
candle_metal_kernels::call_elu_strided(
|
||||||
@ -483,11 +483,11 @@ impl BackendStorage for MetalStorage {
|
|||||||
// The reduction loop requires the shared array to be properly initialized and for
|
// The reduction loop requires the shared array to be properly initialized and for
|
||||||
// this we want the number of threads to be a power of two.
|
// this we want the number of threads to be a power of two.
|
||||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
|
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||||
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
|
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||||
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
|
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
|
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
|
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||||
_ => crate::bail!("Reduce op for non float"),
|
_ => crate::bail!("Reduce op for non float"),
|
||||||
};
|
};
|
||||||
if check_empty && layout.shape().elem_count() == 0 {
|
if check_empty && layout.shape().elem_count() == 0 {
|
||||||
|
@ -109,16 +109,16 @@ kernel void FN_NAME##_strided( \
|
|||||||
} \
|
} \
|
||||||
|
|
||||||
|
|
||||||
AFFINE(affine_float, float)
|
AFFINE(affine_f32, float)
|
||||||
AFFINE(affine_half, half)
|
AFFINE(affine_f16, half)
|
||||||
POWF(powf_float, float)
|
POWF(powf_f32, float)
|
||||||
POWF(powf_half, half)
|
POWF(powf_f16, half)
|
||||||
ELU(elu_float, float)
|
ELU(elu_f32, float)
|
||||||
ELU(elu_half, half)
|
ELU(elu_f16, half)
|
||||||
|
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
AFFINE(affine_bfloat, bfloat);
|
AFFINE(affine_bf16, bfloat);
|
||||||
POWF(powf_bfloat, bfloat);
|
POWF(powf_bf16, bfloat);
|
||||||
ELU(elu_bfloat, bfloat);
|
ELU(elu_bf16, bfloat);
|
||||||
#endif
|
#endif
|
||||||
|
@ -52,11 +52,11 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define BINARY_OP(FN, NAME) \
|
#define BINARY_OP(FN, NAME) \
|
||||||
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
|
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
|
||||||
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
|
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
|
||||||
|
|
||||||
#define BFLOAT_BINARY_OP(FN, NAME) \
|
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||||
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||||
|
|
||||||
|
|
||||||
BINARY_OP(x + y, add)
|
BINARY_OP(x + y, add)
|
||||||
|
@ -125,16 +125,16 @@ macro_rules! ops{
|
|||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
|
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
|
||||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
|
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
|
||||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
pub mod copy {
|
pub mod copy {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
pub const FLOAT: Kernel = Kernel("copy_float");
|
pub const FLOAT: Kernel = Kernel("copy_f32");
|
||||||
pub const HALF: Kernel = Kernel("copy_half");
|
pub const HALF: Kernel = Kernel("copy_f16");
|
||||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat");
|
pub const BFLOAT: Kernel = Kernel("copy_bf16");
|
||||||
pub const U32: Kernel = Kernel("copy_u32");
|
pub const U32: Kernel = Kernel("copy_u32");
|
||||||
pub const U8: Kernel = Kernel("copy_u8");
|
pub const U8: Kernel = Kernel("copy_u8");
|
||||||
}
|
}
|
||||||
@ -145,16 +145,16 @@ macro_rules! ops{
|
|||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
|
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
|
||||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
|
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
|
||||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
|
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
|
||||||
}
|
}
|
||||||
)+
|
)+
|
||||||
pub mod copy {
|
pub mod copy {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
pub const FLOAT: Kernel = Kernel("copy_float_strided");
|
pub const FLOAT: Kernel = Kernel("copy_f32_strided");
|
||||||
pub const HALF: Kernel = Kernel("copy_half_strided");
|
pub const HALF: Kernel = Kernel("copy_f16_strided");
|
||||||
pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
|
pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
|
||||||
pub const U32: Kernel = Kernel("copy_u32_strided");
|
pub const U32: Kernel = Kernel("copy_u32_strided");
|
||||||
pub const U8: Kernel = Kernel("copy_u8_strided");
|
pub const U8: Kernel = Kernel("copy_u8_strided");
|
||||||
}
|
}
|
||||||
|
@ -71,9 +71,9 @@ kernel void NAME( \
|
|||||||
} \
|
} \
|
||||||
|
|
||||||
|
|
||||||
REDUCE(x + y, fast_sum_float, float)
|
REDUCE(x + y, fast_sum_f32, float)
|
||||||
REDUCE(x * y, fast_mul_float, float)
|
REDUCE(x * y, fast_mul_f32, float)
|
||||||
REDUCE(max(x, y), fast_max_float, float)
|
REDUCE(max(x, y), fast_max_f32, float)
|
||||||
|
|
||||||
#define SOFTMAX(NAME, T) \
|
#define SOFTMAX(NAME, T) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
@ -142,8 +142,8 @@ kernel void NAME(
|
|||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
SOFTMAX(softmax_float, float)
|
SOFTMAX(softmax_f32, float)
|
||||||
SOFTMAX(softmax_half, half)
|
SOFTMAX(softmax_f16, half)
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
SOFTMAX(softmax_bfloat, bfloat)
|
SOFTMAX(softmax_bf16, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
@ -87,11 +87,11 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define UNARY_OP(NAME) \
|
#define UNARY_OP(NAME) \
|
||||||
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
|
UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \
|
||||||
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
|
UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
|
||||||
|
|
||||||
#define BFLOAT_UNARY_OP(NAME) \
|
#define BFLOAT_UNARY_OP(NAME) \
|
||||||
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||||
|
|
||||||
|
|
||||||
UNARY_OP(cos)
|
UNARY_OP(cos)
|
||||||
@ -108,8 +108,8 @@ UNARY_OP(round)
|
|||||||
UNARY_OP(gelu_erf)
|
UNARY_OP(gelu_erf)
|
||||||
UNARY_OP(erf)
|
UNARY_OP(erf)
|
||||||
UNARY_OP(tanh)
|
UNARY_OP(tanh)
|
||||||
UNARY(id, float, copy_float, copy_float_strided)
|
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||||
UNARY(id, half, copy_half, copy_half_strided)
|
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||||
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||||
|
|
||||||
@ -129,5 +129,5 @@ BFLOAT_UNARY_OP(gelu_erf)
|
|||||||
BFLOAT_UNARY_OP(erf)
|
BFLOAT_UNARY_OP(erf)
|
||||||
BFLOAT_UNARY_OP(tanh)
|
BFLOAT_UNARY_OP(tanh)
|
||||||
|
|
||||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||||
#endif
|
#endif
|
||||||
|
@ -213,9 +213,9 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
let command_buffer = device.command_buffer();
|
let command_buffer = device.command_buffer();
|
||||||
let kernels = device.kernels();
|
let kernels = device.kernels();
|
||||||
let name = match storage.dtype() {
|
let name = match storage.dtype() {
|
||||||
DType::F32 => "softmax_float",
|
DType::F32 => "softmax_f32",
|
||||||
DType::F16 => "softmax_half",
|
DType::F16 => "softmax_f16",
|
||||||
DType::BF16 => "softmax_bfloat",
|
DType::BF16 => "softmax_bf16",
|
||||||
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
|
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user