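//! Criterion benchmarks for reduction (`sum_keepdim`) and arg-reduction
//! (`argmin_keepdim`) ops over contiguous and strided (transposed) tensors
//! in f32, f16, and bf16, run on every available benchmark device.
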
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::time::Instant;

/// Sum-reduce over the innermost dimension, keeping the reduced dim.
fn run_sum(a: &Tensor) {
    a.sum_keepdim(2).unwrap();
}

/// Arg-min reduce over the innermost dimension, keeping the reduced dim.
fn run_arg_min(a: &Tensor) {
    a.argmin_keepdim(2).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    let (lo, up) = (-1000.0f32, 1000.0f32);
    for device in handler.devices {
        // Contiguous reduce for f32/f16/bf16.
        run_reduce(c, &device, (lo, up), false);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        // Contiguous arg-reduce for f32/f16/bf16.
        run_arg_reduce(c, &device, (lo, up), false);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        // Strided (transposed) reduce for f32/f16/bf16.
        run_reduce(c, &device, (lo, up), true);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);

        // Strided (transposed) arg-reduce for f32/f16/bf16.
        run_arg_reduce(c, &device, (lo, up), true);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
    }
}

fn run_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    // For the strided case, transpose so the reduced dimension is no longer contiguous.
    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), device).unwrap()
    };

    // Bytes read per iteration, used for throughput reporting.
    let bytes = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "reduce_f32_strided"
            } else {
                "reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "reduce_f16_strided"
            } else {
                "reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "reduce_bf16_strided"
            } else {
                "reduce_bf16"
            }
        }
        _ => "unknown",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(bytes as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_sum(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn run_arg_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    // For the strided case, transpose so the reduced dimension is no longer contiguous.
    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), device).unwrap()
    };

    // Bytes read per iteration, used for throughput reporting.
    let bytes = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "arg_reduce_f32_strided"
            } else {
                "arg_reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "arg_reduce_f16_strided"
            } else {
                "arg_reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "arg_reduce_bf16_strided"
            } else {
                "arg_reduce_bf16"
            }
        }
        _ => "unknown",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(bytes as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_arg_min(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

criterion_group!(benches, criterion_benchmark);
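
// This file only defines the `benches` group. Below is a minimal sketch of how such a
// group is typically collected by a Criterion harness; the `bench_main.rs` file name and
// the `benchmarks::reduce` module path are assumptions for illustration, not taken from
// this file:
//
// // benches/bench_main.rs
// use criterion::criterion_main;
// mod benchmarks;
// criterion_main!(benchmarks::reduce::benches);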