ivarflakstad 7c2449f623 Metal: Improved reduce and softmax (#1819)
* Improve reduce perf and add contiguous impl

* Improve arg reduce and add contiguous impl

* Improve softmax kernel. 33%-39% higher throughput

* fmt

* Fixed all bugs. Improved code quality. Added tests.

* Stash for debugging

* Stash for debugging 2

* Fixing argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait

* Online softmax. Improved threadgroup reduce. Tidying up a bit.

* Remove redundant threadgroup_barrier from arg reduce

* Mostly tidying up. Some improvements

* Simplify indexed struct

* tidying

* Reuse operation operator instead of passing it in as a parameter

* Fix how operators are applied to indexed<vec<T,N>>

* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.

* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.

* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math

* Use constant for input instead of const device. Fix strided reduce.

* Use contiguous reduce in tests

* Rename finalize -> to_scalar

* Support integer types max/min (switch with trait-inferred impl later)

* Was worried I was skipping work -> shuffling the 1D test cases

* Add build.rs to avoid metal kernel jit compile overhead

* Improve build. Extract utils

* Compile metal kernels for both macos and ios

* Fixed over xmas and then forgot about it

* Add calculate_reduce_threads util

* Remove old reduce.metal

* Improve f16/bf16 softmax precision by accumulating in f32

* Remove build.rs (for now)

* Move softmax bench to candle-nn

* Remove redundant thread calc util fn

* Use uint over ushort for indices etc

* Use fast exp in MDReduceOp

* Remove nested metal define for softmax

* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-02-08 07:27:01 +01:00
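
The commit message mentions an "online softmax" but does not show it. As a rough CPU reference for what that term usually means, here is a minimal sketch in plain Rust. It is not the Metal kernel from this commit: the function name `online_softmax` is hypothetical, the input is assumed to be finite `f32` values, and the loop is scalar rather than vectorized or reduced across threads.

```rust
/// Softmax over a slice using the "online" formulation: the running maximum
/// and the running normalizer are updated together in a single pass.
/// Hypothetical helper for illustration; assumes all inputs are finite.
fn online_softmax(xs: &[f32]) -> Vec<f32> {
    let mut max = f32::NEG_INFINITY;
    let mut sum = 0.0f32;
    for &x in xs {
        let new_max = max.max(x);
        // Rescale the sum accumulated so far to the new maximum,
        // then add the contribution of the current element.
        sum = sum * (max - new_max).exp() + (x - new_max).exp();
        max = new_max;
    }
    // Second pass: normalize each element with the final max and sum.
    xs.iter().map(|&x| (x - max).exp() / sum).collect()
}
```

Compared with the classic three-pass version (max, then sum, then normalize), this fuses the first two passes, so the input is read twice instead of three times. Subtracting the running maximum keeps the exponentials from overflowing even for inputs as large as the ±1000 range used by the benchmark in this file.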


use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle::{DType, Device, Tensor};
use candle_nn::ops::softmax_last_dim;
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

// Runs one softmax over the last dimension and discards the result.
fn run(input: &Tensor) {
    let _ = softmax_last_dim(input).unwrap();
}

// Input shape is (B, M, K); softmax is applied along K.
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;

fn run_softmax_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
    let elements = B * M * K;

    // Random values are generated as f32 and then cast to the dtype under test.
    let input = Tensor::rand(-1000.0f32, 1000.0f32, (B, M, K), device)
        .unwrap()
        .to_dtype(dtype)
        .unwrap();

    // Throughput is reported in input bytes per second, not in FLOPs.
    let bytes = elements * dtype.size_in_bytes();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(bytes as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(black_box(&input));
            }
            // Synchronize with the device before stopping the timer.
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let device = BenchDeviceHandler::new().unwrap();
    for d in device.devices {
        run_softmax_benchmark(c, &d, DType::F32, "softmax_f32");
        run_softmax_benchmark(c, &d, DType::BF16, "softmax_bf16");
        run_softmax_benchmark(c, &d, DType::F16, "softmax_f16");
    }
}

criterion_group!(benches, criterion_benchmark);
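
The throughput this benchmark reports is derived from the size of the input tensor in bytes. The standalone sketch below reproduces that arithmetic outside Criterion so the numbers can be checked by hand. The helper names `input_bytes` and `bytes_per_second` are hypothetical and not part of candle or Criterion; the element sizes assumed are 4 bytes for f32 and 2 bytes for f16 and bf16.

```rust
use std::time::Duration;

/// Size in bytes of a (b, m, k) tensor whose elements are `elem_size` bytes wide.
/// Hypothetical helper mirroring `elements * dtype.size_in_bytes()`.
fn input_bytes(b: usize, m: usize, k: usize, elem_size: usize) -> u64 {
    (b * m * k * elem_size) as u64
}

/// Bytes processed per second for one pass over `bytes` that took `elapsed`.
fn bytes_per_second(bytes: u64, elapsed: Duration) -> f64 {
    bytes as f64 / elapsed.as_secs_f64()
}
```

For the (1, 1024, 1024) shape used here, an f32 input is 4 MiB and an f16 or bf16 input is 2 MiB. At equal wall-clock time, the half-precision runs therefore show half the bytes-per-second figure of the f32 run.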