mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Metal: Improved reduce and softmax (#1819)
* Improve reduce perf and add contiguous impl * Improve arg reduce and add contiguous impl * Improve softmax kernel. 33%-39% higher thrpt * fmt * Fixed all bugs. Improved code quality. Added tests. * Stash for debugging * Stash for debugging 2 * Fixing argmax bug and improve performance Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com> * Fix test and add is_valid_simgroup_reduce_type trait * Online softmax. Improved threadgroup reduce. Tidying up a bit. * Remove redundant threadgroup_barrier from arg reduce * Mostly tidying up. Some improvements * Simplify indexed struct * tidying * Reuse operation operator instead of passing it in as a parameter * Fix how operators are applied to indexed<vec<T,N>> * Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce. * Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal. * Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math * Use constant for input instead of const device. Fix strided reduce. * Use contiguous reduce in tests * Rename finalize -> to_scalar * Support integer types max/min (switch with trait-inferred impl later) * Was worried I was skipping work -> shuffling the 1D test cases * Add build.rs to avoid metal kernel jit compile overhead * Improve build. Extract utils * Compile metal kernels for both macos and ios * Fixed over xmas and then forgot about it * Add calculate_reduce_threads util * Remove old reduce.metal * Improve f16/bf16 softmax precision by accumulating in f32 * Remove build.rs (for now) * Move softmax bench to candle-nn * Remove redundant thread calc util fn * Use uint over ushort for indices etc * Use fast exp in MDReduceOp * Remove nested metal define for softmax * Fix some clippy lint. --------- Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com> Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1,10 +1,12 @@
|
||||
mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::reduce::benches,
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
|
@ -3,6 +3,7 @@ pub(crate) mod conv_transpose2d;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod qmatmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod reduce;
|
||||
pub(crate) mod unary;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
|
158
candle-core/benches/benchmarks/reduce.rs
Normal file
158
candle-core/benches/benchmarks/reduce.rs
Normal file
@ -0,0 +1,158 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use half::{bf16, f16};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run_sum(a: &Tensor) {
|
||||
a.sum_keepdim(2).unwrap();
|
||||
}
|
||||
fn run_arg_min(a: &Tensor) {
|
||||
a.argmin_keepdim(2).unwrap();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
let (lo, up) = (-1000.0f32, 1000.0f32);
|
||||
for device in handler.devices {
|
||||
run_reduce(c, &device, (lo, up), false);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), false);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
|
||||
|
||||
run_reduce(c, &device, (lo, up), true);
|
||||
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
|
||||
run_arg_reduce(c, &device, (lo, up), true);
|
||||
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
|
||||
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"reduce_f32_strided"
|
||||
} else {
|
||||
"reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"reduce_f16_strided"
|
||||
} else {
|
||||
"reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"reduce_bf16_strided"
|
||||
} else {
|
||||
"reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run_sum(black_box(&a));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn run_arg_reduce<T: candle_core::FloatDType>(
|
||||
c: &mut Criterion,
|
||||
device: &Device,
|
||||
(lo, up): (T, T),
|
||||
strided: bool,
|
||||
) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let a = if strided {
|
||||
Tensor::rand(lo, up, (b, m, k), &device)
|
||||
.unwrap()
|
||||
.transpose(0, 2)
|
||||
.unwrap()
|
||||
} else {
|
||||
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
|
||||
};
|
||||
|
||||
let flops = b * m * k * T::DTYPE.size_in_bytes();
|
||||
|
||||
let name = match T::DTYPE {
|
||||
DType::F32 => {
|
||||
if strided {
|
||||
"arg_reduce_f32_strided"
|
||||
} else {
|
||||
"arg_reduce_f32"
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
if strided {
|
||||
"arg_reduce_f16_strided"
|
||||
} else {
|
||||
"arg_reduce_f16"
|
||||
}
|
||||
}
|
||||
DType::BF16 => {
|
||||
if strided {
|
||||
"arg_reduce_bf16_strided"
|
||||
} else {
|
||||
"arg_reduce_bf16"
|
||||
}
|
||||
}
|
||||
_ => "unknown",
|
||||
};
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run_arg_min(black_box(&a));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -2,7 +2,6 @@ use crate::{DType, Result};
|
||||
use candle_metal_kernels::Kernels;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
|
||||
@ -236,7 +235,7 @@ impl MetalDevice {
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
|
||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||
let new_buffer = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const c_void,
|
||||
data.as_ptr().cast(),
|
||||
size,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
|
@ -265,6 +265,7 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
let device = self.device.clone();
|
||||
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
// Source dims and strides with the sum dims at the end.
|
||||
@ -278,13 +279,72 @@ impl BackendStorage for MetalStorage {
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
for &dim_idx in sum_dims.iter() {
|
||||
dims.push(src_dims[dim_idx]);
|
||||
stride.push(src_stride[dim_idx]);
|
||||
}
|
||||
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let reduction_shape = Shape::from(dims.clone());
|
||||
|
||||
if layout.is_contiguous() && reduction_shape.is_contiguous(&stride) {
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
|
||||
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
|
||||
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
|
||||
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
|
||||
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
|
||||
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
|
||||
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
|
||||
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
|
||||
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
|
||||
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
|
||||
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
|
||||
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
|
||||
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
|
||||
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
|
||||
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
|
||||
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
|
||||
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
|
||||
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
|
||||
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
|
||||
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
|
||||
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
|
||||
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
|
||||
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
|
||||
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
|
||||
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
|
||||
(k, dtype) => {
|
||||
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
src_dims,
|
||||
dst_el,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
return Ok(Self::new(buffer, device, dst_el, dtype));
|
||||
}
|
||||
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
|
||||
@ -316,7 +376,7 @@ impl BackendStorage for MetalStorage {
|
||||
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
|
||||
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
||||
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
|
Reference in New Issue
Block a user