mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00

* Improve reduce perf and add contiguous impl
* Improve arg reduce and add contiguous impl
* Improve softmax kernel. 33%-39% higher thrpt
* fmt
* Fixed all bugs. Improved code quality. Added tests.
* Stash for debugging
* Stash for debugging 2
* Fixing argmax bug and improve performance
  Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
* Fix test and add is_valid_simgroup_reduce_type trait
* Online softmax. Improved threadgroup reduce. Tidying up a bit.
* Remove redundant threadgroup_barrier from arg reduce
* Mostly tidying up. Some improvements
* Simplify indexed struct
* tidying
* Reuse operation operator instead of passing it in as a parameter
* Fix how operators are applied to indexed<vec<T,N>>
* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.
* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.
* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math
* Use constant for input instead of const device. Fix strided reduce.
* Use contiguous reduce in tests
* Rename finalize -> to_scalar
* Support integer types max/min (switch with trait-inferred impl later)
* Was worried I was skipping work -> shuffling the 1D test cases
* Add build.rs to avoid metal kernel jit compile overhead
* Improve build. Extract utils
* Compile metal kernels for both macos and ios
* Fixed over xmas and then forgot about it
* Add calculate_reduce_threads util
* Remove old reduce.metal
* Improve f16/bf16 softmax precision by accumulating in f32
* Remove build.rs (for now)
* Move softmax bench to candle-nn
* Remove redundant thread calc util fn
* Use uint over ushort for indices etc
* Use fast exp in MDReduceOp
* Remove nested metal define for softmax
* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
9 lines
169 B
Rust
mod benchmarks;

use criterion::criterion_main;
criterion_main!(
    benchmarks::softmax::benches,
    benchmarks::layer_norm::benches,
    benchmarks::conv::benches
);
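
The commit message above mentions "Online softmax" and accumulating in f32 for the kernel that `benchmarks::softmax` measures. For readers unfamiliar with the technique, here is a minimal CPU sketch of a single-pass (online) softmax in plain Rust. It is not the Metal kernel from this commit; the function name `online_softmax` is hypothetical, and the sketch assumes all inputs are finite.

```rust
/// Hypothetical reference implementation of online softmax (not candle's kernel).
/// One pass over the input keeps a running maximum `m` and a running sum `d`
/// of exp(x - m); `d` is rescaled whenever the maximum grows.
/// Assumes every input value is finite.
fn online_softmax(xs: &[f32]) -> Vec<f32> {
    if xs.is_empty() {
        return Vec::new();
    }
    let mut m = f32::NEG_INFINITY; // running maximum
    let mut d = 0.0f32; // running sum of exp(x - m)
    for &x in xs {
        let m_new = m.max(x);
        // Rescale the old sum to the new maximum, then add the new term.
        d = d * (m - m_new).exp() + (x - m_new).exp();
        m = m_new;
    }
    // Second loop only normalizes; max and sum came from the single pass above.
    xs.iter().map(|&x| (x - m).exp() / d).collect()
}
```

Subtracting the running maximum before exponentiating keeps large inputs such as `[1000.0, 1000.0]` from overflowing, which is the same motivation as in the usual two-pass stable softmax; the online form simply avoids a separate pass to find the maximum.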