Metal: Improved reduce and softmax (#1819)

* Improve reduce perf and add contiguous impl

* Improve arg reduce and add contiguous impl

* Improve softmax kernel. 33%-39% higher throughput

* fmt

* Fixed all bugs. Improved code quality. Added tests.

* Stash for debugging

* Stash for debugging 2

* Fix argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait

* Online softmax. Improved threadgroup reduce. Tidying up a bit. (A scalar sketch of the online-softmax pass follows the commit metadata below.)

* Remove redundant threadgroup_barrier from arg reduce

* Mostly tidying up. Some improvements

* Simplify indexed struct

* tidying

* Reuse operation operator instead of passing it in as a parameter

* Fix how operators are applied to indexed<vec<T,N>>

* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.

* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.

* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math

* Use constant for input instead of const device. Fix strided reduce.

* Use contiguous reduce in tests

* Rename finalize -> to_scalar

* Support integer types max/min (switch with trait-inferred impl later)

* Shuffle the 1D test cases to make sure no work is being skipped

* Add build.rs to avoid metal kernel jit compile overhead

* Improve build. Extract utils

* Compile metal kernels for both macos and ios

* Fixed over xmas and then forgot about it

* Add calculate_reduce_threads util

* Remove old reduce.metal

* Improve f16/bf16 softmax precision by accumulating in f32

* Remove build.rs (for now)

* Move softmax bench to candle-nn

* Remove redundant thread calc util fn

* Use uint over ushort for indices etc

* Use fast exp in MDReduceOp

* Remove nested metal define for softmax

* Fix some clippy lints.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Authored by ivarflakstad on 2025-02-08 07:27:01 +01:00, committed by GitHub
parent 0af3e428ec
commit 7c2449f623
12 changed files with 1521 additions and 357 deletions
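
For context on the "Online softmax" and f32-accumulation bullets above: the kernel keeps a running maximum and a running exponential sum in a single pass over each row, rescaling the partial sum whenever a larger value is seen, and accumulates in f32 even for f16/bf16 inputs. A minimal scalar sketch of that idea in Rust (illustrative only; the function name is made up here, and the actual Metal kernel additionally vectorizes loads and reduces across simdgroups/threadgroups):

use half::f16;

// Single-pass ("online") softmax over one row, accumulating in f32.
fn online_softmax_row(row: &[f16]) -> Vec<f16> {
    let mut max = f32::NEG_INFINITY;
    let mut sum = 0.0f32;
    for &x in row {
        let x = x.to_f32();
        if x > max {
            // Rescale the partial sum from the old maximum to the new one.
            sum *= (max - x).exp();
            max = x;
        }
        sum += (x - max).exp();
    }
    // A second pass writes the normalized values back in the input precision.
    row.iter()
        .map(|&x| f16::from_f32((x.to_f32() - max).exp() / sum))
        .collect()
}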

View File

@@ -1,10 +1,12 @@
mod benchmarks;
use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,

View File

@@ -3,6 +3,7 @@ pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
pub(crate) mod reduce;
pub(crate) mod unary;
pub(crate) mod where_cond;

View File

@@ -0,0 +1,158 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::time::Instant;
fn run_sum(a: &Tensor) {
a.sum_keepdim(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
a.argmin_keepdim(2).unwrap();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
let (lo, up) = (-1000.0f32, 1000.0f32);
for device in handler.devices {
run_reduce(c, &device, (lo, up), false);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_arg_reduce(c, &device, (lo, up), false);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_reduce(c, &device, (lo, up), true);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
run_arg_reduce(c, &device, (lo, up), true);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
}
}
fn run_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let name = match T::DTYPE {
DType::F32 => {
if strided {
"reduce_f32_strided"
} else {
"reduce_f32"
}
}
DType::F16 => {
if strided {
"reduce_f16_strided"
} else {
"reduce_f16"
}
}
DType::BF16 => {
if strided {
"reduce_bf16_strided"
} else {
"reduce_bf16"
}
}
_ => "unknown",
};
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_sum(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn run_arg_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let name = match T::DTYPE {
DType::F32 => {
if strided {
"arg_reduce_f32_strided"
} else {
"arg_reduce_f32"
}
}
DType::F16 => {
if strided {
"arg_reduce_f16_strided"
} else {
"arg_reduce_f16"
}
}
DType::BF16 => {
if strided {
"arg_reduce_bf16_strided"
} else {
"arg_reduce_bf16"
}
}
_ => "unknown",
};
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_arg_min(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
criterion_group!(benches, criterion_benchmark);

View File

@@ -2,7 +2,6 @@ use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
@@ -236,7 +235,7 @@ impl MetalDevice {
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
data.as_ptr().cast(),
size,
MTLResourceOptions::StorageModeManaged,
);

View File

@@ -265,6 +265,7 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
// Source dims and strides with the sum dims at the end.
@@ -278,13 +279,72 @@ impl BackendStorage for MetalStorage {
stride.push(src_stride[dim_idx]);
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let reduction_shape = Shape::from(dims.clone());
if layout.is_contiguous() && reduction_shape.is_contiguous(&stride) {
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
(k, dtype) => {
crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
}
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
src_dims,
dst_el,
src,
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, dst_el, dtype));
}
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
@@ -316,7 +376,7 @@ impl BackendStorage for MetalStorage {
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
(k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
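
To make the new contiguous fast path above concrete, here is a small illustrative Rust check (shape values and the function name are made up) of how the kept dims and the reduced dims are reordered, and when the resulting (dims, stride) pair stays contiguous so the fast_* kernels rather than the fast_*_strided ones are used:

fn contiguous_fast_path_example() {
    // Reducing the *last* dim of a contiguous (2, 3, 4) tensor.
    let src_dims = [2usize, 3, 4];
    let src_stride = [12usize, 4, 1];
    let sum_dims = [2usize];

    // Kept dims first, reduced dims last, mirroring the reordering above.
    let mut dims = Vec::new();
    let mut stride = Vec::new();
    for (i, (&d, &s)) in src_dims.iter().zip(&src_stride).enumerate() {
        if !sum_dims.contains(&i) {
            dims.push(d);
            stride.push(s);
        }
    }
    for &i in &sum_dims {
        dims.push(src_dims[i]);
        stride.push(src_stride[i]);
    }
    // Still contiguous, so the fast contiguous kernel applies.
    assert_eq!((dims, stride), (vec![2, 3, 4], vec![12, 4, 1]));
    // Reducing the middle dim instead would give dims [2, 4, 3] with strides
    // [12, 1, 4], which is not contiguous, so that case still falls through
    // to the fast_*_strided kernels below.
}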

View File

@@ -5,14 +5,12 @@ use metal::{
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
pub mod mlx_gemm;
pub mod sort;
pub mod utils;
pub use utils::BufferOffset;
pub use mlx_gemm::{call_mlx_gemm, GemmDType};
pub use sort::{call_arg_sort, call_mlx_arg_sort};
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
@@ -176,7 +174,7 @@ pub enum MetalKernelError {
LockError(String),
#[error("Error while loading library: {0}")]
LoadLibraryError(String),
#[error("Error while loading function: {0:?}")]
#[error("Error while loading function: {0}")]
LoadFunctionError(String),
#[error("Failed to create compute function")]
FailedToCreateComputeFunction,
@@ -635,19 +633,31 @@ pub fn call_reduce_contiguous(
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
shape: &[usize],
out_length: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length = shape.iter().product::<usize>();
let num_dims = shape.len();
let work_per_threadgroup = length / out_length;
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, &input, output));
set_params!(
encoder,
(
length,
num_dims,
shape,
work_per_threadgroup,
&input,
output
)
);
let thread_group_count = MTLSize {
width: out_length as u64,
@@ -657,9 +667,8 @@ pub fn call_reduce_contiguous(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
(elements_to_sum as u64).div_ceil(2),
)
.next_power_of_two();
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
);
let thread_group_size = MTLSize {
width,
@@ -686,8 +695,9 @@ pub fn call_reduce_strided(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
let num_dims = shape.len();
let work_per_threadgroup = length / out_length;
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
@@ -695,7 +705,15 @@ pub fn call_reduce_strided(
set_params!(
encoder,
(shape.len(), shape, strides, elements_to_sum, &input, output)
(
length,
num_dims,
shape,
strides,
work_per_threadgroup,
&input,
output
)
);
let thread_group_count = MTLSize {
@@ -706,16 +724,14 @@ pub fn call_reduce_strided(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
@@ -729,11 +745,13 @@ pub fn call_last_softmax(
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
elements_to_sum: usize,
elements: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let work_per_threadgroup = elements;
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
@@ -741,29 +759,27 @@ pub fn call_last_softmax(
set_params!(
encoder,
(length, elements_to_sum, (input, input_offset), output)
(length, work_per_threadgroup, (input, input_offset), output)
);
let out_length = length / elements_to_sum;
let out_length = length / work_per_threadgroup;
let thread_group_count = MTLSize {
width: out_length as u64,
width: out_length as NSUInteger,
height: 1,
depth: 1,
};
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
(work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
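
The dispatch sites above now all size the threadgroup the same way: take half the per-threadgroup workload, round up to a power of two (a power-of-two width keeps the shared-memory tree reduction simple, and the halving presumably reflects each thread combining at least two elements up front), and cap it at the pipeline's per-threadgroup thread limit. A standalone Rust sketch of that rule (the function name is illustrative, not part of the API):

fn reduce_threadgroup_width(work_per_threadgroup: usize, max_threads_per_threadgroup: usize) -> usize {
    // Half the workload, rounded up to a power of two, but never more than
    // the pipeline allows per threadgroup.
    let wanted = (work_per_threadgroup / 2).next_power_of_two();
    wanted.min(max_threads_per_threadgroup)
}

For example, a 1024-element row on a pipeline allowing 1024 threads per threadgroup gets 512 threads, so each thread starts by combining two elements.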

File diff suppressed because it is too large.

View File

@@ -1,6 +1,8 @@
use super::*;
use half::{bf16, f16};
use metal::MTLResourceOptions;
use metal::{Buffer, Device, MTLResourceOptions};
use rand::prelude::SliceRandom;
use rand::thread_rng;
use rand::Rng;
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
@@ -860,7 +862,12 @@ fn cos_f16() {
assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
fn run_reduce<T, U: Clone>(
v: &[T],
in_length: usize,
out_length: usize,
name: &'static str,
) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
@@ -868,21 +875,24 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
let dims = vec![v.len()];
let strides = vec![1];
call_reduce_strided(
let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
let shape = vec![in_length];
match call_reduce_contiguous(
&device,
command_buffer,
&kernels,
name,
&dims,
&strides,
&shape,
out_length,
BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
) {
Ok(_) => {}
Err(e) => {
println!("{e}");
panic!();
}
}
command_buffer.commit();
command_buffer.wait_until_completed();
@@ -914,22 +924,187 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
read_to_vec(&output, v.len())
}
#[test]
fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
const fn create_array<const N: usize>() -> [f32; N] {
let mut array: [f32; N] = [0.0; N];
let mut i = 1;
while i <= N {
array[i - 1] = i as f32;
i += 1;
}
array
}
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![21.0]);
const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
let mut sum = 0;
let mut results: [f32; D] = [0.0; D];
let mut i = 1;
let mut j = 1;
while i <= N {
sum += i;
i += 1;
if i > j * N / D {
results[j - 1] = sum as f32;
j += 1;
sum = 0;
}
}
results
}
const fn correct_max<const N: usize, const D: usize>() -> [f32; D] {
let mut results: [f32; D] = [0.0; D];
let mut i = 1;
let mut j = 1;
while i <= N {
i += 1;
if i > j * (N / D) {
results[j - 1] = (i - 1) as f32;
j += 1;
}
}
results
}
fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
let mut max = 0.0;
let mut max_index: u32 = 0;
let mut results: [u32; D] = [0; D];
let mut i = 0;
let mut j = 1;
while i <= N {
if i >= (j * N / D) {
results[j - 1] = max_index;
max = 0.0;
max_index = 0;
j += 1;
}
if i == N {
break;
}
if arr[i] > max {
max = arr[i];
max_index = i as u32;
}
i += 1;
}
results
}
fn reduce_sum_case<const N: usize, const D: usize>() {
let mut v = create_array::<N>();
if D == 1 {
// Hardens 1-dimensional test cases
v.shuffle(&mut thread_rng());
}
let results = run_reduce(&v, N, D, "fast_sum_f32");
assert_eq!(approx(results, 4), correct_sum::<N, D>());
}
fn reduce_max_case<const N: usize, const D: usize>() {
let mut v = create_array::<N>();
if D == 1 {
// Hardens 1-dimensional test cases
v.shuffle(&mut thread_rng());
}
let results = run_reduce(&v, N, D, "fast_max_f32");
assert_eq!(approx(results, 4), correct_max::<N, D>());
}
fn reduce_argmax_case<const N: usize, const D: usize>() {
let mut v = create_array::<N>();
if D == 1 {
// Hardens 1-dimensional test cases
v.shuffle(&mut thread_rng());
}
let results: Vec<u32> = run_reduce(&v, N, D, "fast_argmax_f32");
assert_eq!(results, correct_argmax::<N, D>(v));
}
#[test]
fn reduce_sum1() {
reduce_sum_case::<9, 1>();
reduce_sum_case::<6, 1>();
reduce_sum_case::<10, 1>();
reduce_sum_case::<64, 1>();
reduce_sum_case::<128, 1>();
reduce_sum_case::<256, 1>();
reduce_sum_case::<512, 1>();
reduce_sum_case::<1024, 1>();
reduce_sum_case::<2048, 1>();
reduce_sum_case::<4096, 1>();
}
#[test]
fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
reduce_sum_case::<6, 2>();
reduce_sum_case::<10, 2>();
reduce_sum_case::<64, 2>();
reduce_sum_case::<128, 2>();
reduce_sum_case::<256, 2>();
reduce_sum_case::<512, 2>();
reduce_sum_case::<1024, 2>();
reduce_sum_case::<2048, 2>();
reduce_sum_case::<4096, 2>();
}
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
#[test]
fn reduce_max() {
reduce_max_case::<6, 1>();
reduce_max_case::<9, 1>();
reduce_max_case::<10, 1>();
reduce_max_case::<64, 1>();
reduce_max_case::<128, 1>();
reduce_max_case::<256, 1>();
reduce_max_case::<512, 1>();
reduce_max_case::<1024, 1>();
reduce_max_case::<2048, 1>();
reduce_max_case::<4096, 1>();
reduce_max_case::<6, 2>();
reduce_max_case::<10, 2>();
reduce_max_case::<64, 2>();
reduce_max_case::<128, 2>();
reduce_max_case::<256, 2>();
reduce_max_case::<512, 2>();
reduce_max_case::<1024, 2>();
reduce_max_case::<2048, 2>();
reduce_max_case::<4096, 2>();
reduce_max_case::<6, 3>();
reduce_max_case::<10, 3>();
reduce_max_case::<64, 3>();
reduce_max_case::<128, 3>();
reduce_max_case::<256, 3>();
reduce_max_case::<512, 3>();
reduce_max_case::<1024, 3>();
reduce_max_case::<2048, 3>();
reduce_max_case::<4096, 3>();
}
#[test]
fn reduce_argmax() {
reduce_argmax_case::<6, 1>();
reduce_argmax_case::<9, 1>();
reduce_argmax_case::<10, 1>();
reduce_argmax_case::<64, 1>();
reduce_argmax_case::<128, 1>();
reduce_argmax_case::<256, 1>();
reduce_argmax_case::<512, 1>();
reduce_argmax_case::<1024, 1>();
reduce_argmax_case::<2048, 1>();
}
#[test]
fn reduce_argmax2() {
reduce_argmax_case::<6, 2>();
reduce_argmax_case::<10, 2>();
reduce_argmax_case::<64, 2>();
reduce_argmax_case::<128, 2>();
reduce_argmax_case::<256, 2>();
reduce_argmax_case::<512, 2>();
reduce_argmax_case::<1024, 2>();
reduce_argmax_case::<2048, 2>();
reduce_argmax_case::<4096, 2>();
}
#[test]
@@ -983,7 +1158,7 @@ fn softmax() {
let results = run_softmax(&v, last_dim, "softmax_f16");
assert_eq!(
approx_f16(results, 4),
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2332, 0.6338]
);
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]

View File

@@ -0,0 +1,47 @@
#pragma once
#include <metal_stdlib>
using namespace metal;
METAL_FUNC uint nonzero(uint n) {
return n == 0 ? 1 : n;
}
template<uint N>
constexpr uint nonzero() {
return N == 0 ? 1 : N;
}
template<typename T>
constexpr ushort granularity() {
return nonzero<vec_elements<T>::value>();
}
METAL_FUNC uint next_p2(uint x) {
return 1 << (32 - clz(x - 1));
}
METAL_FUNC uint prev_p2(uint x) {
return 1 << (31 - clz(x));
}
constant uint MAX_SHARED_MEM = 32767;
template<typename T>
METAL_FUNC uint max_shared_mem(uint n) {
return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
}
METAL_FUNC uint get_strided_index(
uint idx,
constant const uint &num_dims,
constant const size_t *dims,
constant const size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
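
For reference, the Metal helpers in this new utils header map onto standard bit tricks and the usual dims/strides walk. An equivalent Rust rendering (illustrative only; the names mirror the Metal file, and prev_p2 assumes x > 0 just like the clz-based original):

// Smallest power of two >= x (Metal: 1 << (32 - clz(x - 1)), for x >= 1).
fn next_p2(x: u32) -> u32 {
    x.next_power_of_two()
}

// Largest power of two <= x (Metal: 1 << (31 - clz(x)), for x >= 1).
fn prev_p2(x: u32) -> u32 {
    1 << (31 - x.leading_zeros())
}

// Cap n to the largest power-of-two count of T-sized elements that fits in the
// 32767-byte threadgroup memory budget used above.
fn max_shared_mem<T>(n: u32) -> u32 {
    const MAX_SHARED_MEM: u32 = 32767;
    n.min(prev_p2(MAX_SHARED_MEM / std::mem::size_of::<T>() as u32))
}

// Map a linear element index into a strided layout, walking dims from the
// innermost dimension outwards.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided = 0;
    for (&dim, &stride) in dims.iter().zip(strides).rev() {
        strided += (idx % dim) * stride;
        idx /= dim;
    }
    strided
}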

View File

@@ -1,4 +1,8 @@
mod benchmarks;
use criterion::criterion_main;
criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches);
criterion_main!(
benchmarks::softmax::benches,
benchmarks::layer_norm::benches,
benchmarks::conv::benches
);

View File

@@ -1,5 +1,6 @@
pub(crate) mod conv;
pub(crate) mod layer_norm;
pub(crate) mod softmax;
use candle::{Device, Result};

View File

@@ -0,0 +1,49 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle::{DType, Device, Tensor};
use candle_nn::ops::softmax_last_dim;
use criterion::Throughput;
use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;
fn run(input: &Tensor) {
let _ = softmax_last_dim(&input).unwrap();
}
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;
fn run_softmax_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let elements = B * M * K;
let input = Tensor::rand(-1000.0f32, 1000.0f32, (B, M, K), &device)
.unwrap()
.to_dtype(dtype)
.unwrap();
let flops = elements * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&input));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_softmax_benchmark(c, &d, DType::F32, "softmax_f32");
run_softmax_benchmark(c, &d, DType::BF16, "softmax_bf16");
run_softmax_benchmark(c, &d, DType::F16, "softmax_f16");
}
}
criterion_group!(benches, criterion_benchmark);