mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Renamed all kernel names.
This commit is contained in:
@ -314,8 +314,8 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float",
|
||||
DType::F16 => "affine_half",
|
||||
DType::F32 => "affine_f32",
|
||||
DType::F16 => "affine_f16",
|
||||
dtype => crate::bail!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine(
|
||||
@ -332,8 +332,8 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_float_strided",
|
||||
DType::F16 => "affine_half_strided",
|
||||
DType::F32 => "affine_f32_strided",
|
||||
DType::F16 => "affine_f16_strided",
|
||||
dtype => crate::bail!("Affine {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_affine_strided(
|
||||
@ -365,8 +365,8 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_float",
|
||||
DType::F16 => "powf_half",
|
||||
DType::F32 => "powf_f32",
|
||||
DType::F16 => "powf_f16",
|
||||
dtype => crate::bail!("Powf {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_powf(
|
||||
@ -382,8 +382,8 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_float_strided",
|
||||
DType::F16 => "powf_half_strided",
|
||||
DType::F32 => "powf_f32_strided",
|
||||
DType::F16 => "powf_f16_strided",
|
||||
dtype => crate::bail!("Powf {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_powf_strided(
|
||||
@ -414,8 +414,8 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_float",
|
||||
DType::F16 => "elu_half",
|
||||
DType::F32 => "elu_f32",
|
||||
DType::F16 => "elu_f16",
|
||||
dtype => crate::bail!("Powf {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_elu(
|
||||
@ -431,8 +431,8 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_float_strided",
|
||||
DType::F16 => "elu_half_strided",
|
||||
DType::F32 => "elu_f32_strided",
|
||||
DType::F16 => "elu_f16_strided",
|
||||
dtype => crate::bail!("Powf {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_elu_strided(
|
||||
@ -483,11 +483,11 @@ impl BackendStorage for MetalStorage {
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||
_ => crate::bail!("Reduce op for non float"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
|
Reference in New Issue
Block a user