Improve reduce perf and add contiguous impl

This commit is contained in:
Ivar Flakstad
2024-01-21 17:32:21 +01:00
parent 88945f2c22
commit d5902840e0
7 changed files with 409 additions and 96 deletions

View File

@ -491,6 +491,7 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
// Source dims and strides with the sum dims at the end.
@ -504,13 +505,72 @@ impl BackendStorage for MetalStorage {
stride.push(src_stride[dim_idx]);
}
}
if layout.is_contiguous() {
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
//(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
//(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
//(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
//(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
//(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
//(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
//(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
//(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
//(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
//(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
//(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
//(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
//(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
_ => ("fall back to strided impl", false, false)
};
if name != "fall back to strided impl" {
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let buffer = device.new_buffer(1, self.dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.shape().elem_count(),
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, self.dtype));
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),