Finish reduce kernels.

This commit is contained in:
Nicolas Patry
2023-12-17 19:07:00 +01:00
parent 6bc92e63cb
commit 972903021c
6 changed files with 258 additions and 39 deletions

View File

@ -482,20 +482,9 @@ impl BackendStorage for MetalStorage {
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
if sum_dims.len() != 1 {
crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet.");
}
if sum_dims[0] != layout.shape().rank() - 1 {
crate::bail!("Non last dim reduce op {op:?} not implemented yet");
}
if layout.stride()[sum_dims[0]] != 1 {
crate::bail!("Non contiguous reduce op {op:?} not implemented yet");
}
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();
// Source dims and strides with the sum dims at the end.
let mut dims = vec![];
let mut stride = vec![];
@ -515,28 +504,41 @@ impl BackendStorage for MetalStorage {
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
_ => crate::bail!("Reduce op for non float"),
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
(k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
if dtype == DType::U32 {
crate::bail!("reduce op {name} is not implemented yet.");
}
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_reduce_contiguous(
candle_metal_kernels::call_reduce_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
src_el,
&dims,
&stride,
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
@ -730,7 +732,7 @@ impl BackendStorage for MetalStorage {
("sub", DType::F16) => contiguous::sub::HALF,
("mul", DType::F16) => contiguous::mul::HALF,
("div", DType::F16) => contiguous::div::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
@ -751,11 +753,15 @@ impl BackendStorage for MetalStorage {
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
("bminimum", DType::F32) => strided::min::FLOAT,
("bmaximum", DType::F32) => strided::max::FLOAT,
("badd", DType::F16) => strided::add::HALF,
("bsub", DType::F16) => strided::sub::HALF,
("bmul", DType::F16) => strided::mul::HALF,
("bdiv", DType::F16) => strided::div::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
("bminimum", DType::F16) => strided::min::HALF,
("bmaximum", DType::F16) => strided::max::HALF,
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
};
candle_metal_kernels::call_binary_strided(
&device.device,