mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Improve reduce perf and add contiguous impl
This commit is contained in:
@ -491,6 +491,7 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
let device = self.device.clone();
|
||||
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
// Source dims and strides with the sum dims at the end.
|
||||
@ -504,13 +505,72 @@ impl BackendStorage for MetalStorage {
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
if layout.is_contiguous() {
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||
//(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||
//(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
|
||||
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
|
||||
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
|
||||
//(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
|
||||
//(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
|
||||
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
|
||||
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
|
||||
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
|
||||
//(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
|
||||
//(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
|
||||
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
|
||||
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
|
||||
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
|
||||
//(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
|
||||
//(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
|
||||
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
|
||||
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
|
||||
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
|
||||
//(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
|
||||
//(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
|
||||
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
|
||||
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
|
||||
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
|
||||
//(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
|
||||
//(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
|
||||
//(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
||||
_ => ("fall back to strided impl", false, false)
|
||||
};
|
||||
|
||||
if name != "fall back to strided impl" {
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
|
||||
|
||||
let buffer = device.new_buffer(1, self.dtype, "reduce")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
layout.shape().elem_count(),
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
return Ok(Self::new(buffer, device, self.dtype));
|
||||
}
|
||||
}
|
||||
|
||||
for &dim_idx in sum_dims.iter() {
|
||||
dims.push(src_dims[dim_idx]);
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
|
||||
|
Reference in New Issue
Block a user