Renamed all kernel names.

This commit is contained in:
Nicolas Patry
2023-12-15 11:24:47 +01:00
parent 34d83377f6
commit 26540641c1
7 changed files with 56 additions and 56 deletions

View File

@ -314,8 +314,8 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
let name = match self.dtype {
DType::F32 => "affine_float",
DType::F16 => "affine_half",
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
dtype => crate::bail!("Affine {dtype:?}"),
};
candle_metal_kernels::call_affine(
@ -332,8 +332,8 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "affine_float_strided",
DType::F16 => "affine_half_strided",
DType::F32 => "affine_f32_strided",
DType::F16 => "affine_f16_strided",
dtype => crate::bail!("Affine {dtype:?}"),
};
candle_metal_kernels::call_affine_strided(
@ -365,8 +365,8 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
let name = match self.dtype {
DType::F32 => "powf_float",
DType::F16 => "powf_half",
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
dtype => crate::bail!("Powf {dtype:?}"),
};
candle_metal_kernels::call_powf(
@ -382,8 +382,8 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "powf_float_strided",
DType::F16 => "powf_half_strided",
DType::F32 => "powf_f32_strided",
DType::F16 => "powf_f16_strided",
dtype => crate::bail!("Powf {dtype:?}"),
};
candle_metal_kernels::call_powf_strided(
@ -414,8 +414,8 @@ impl BackendStorage for MetalStorage {
let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
let name = match self.dtype {
DType::F32 => "elu_float",
DType::F16 => "elu_half",
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
dtype => crate::bail!("Powf {dtype:?}"),
};
candle_metal_kernels::call_elu(
@ -431,8 +431,8 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "elu_float_strided",
DType::F16 => "elu_half_strided",
DType::F32 => "elu_f32_strided",
DType::F16 => "elu_f16_strided",
dtype => crate::bail!("Powf {dtype:?}"),
};
candle_metal_kernels::call_elu_strided(
@ -483,11 +483,11 @@ impl BackendStorage for MetalStorage {
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
_ => crate::bail!("Reduce op for non float"),
};
if check_empty && layout.shape().elem_count() == 0 {