Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Metal: Improved reduce and softmax (#1819)
* Improve reduce perf and add contiguous impl
* Improve arg reduce and add contiguous impl
* Improve softmax kernel. 33%-39% higher thrpt
* fmt
* Fixed all bugs. Improved code quality. Added tests.
* Stash for debugging
* Stash for debugging 2
* Fixing argmax bug and improve performance
  Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
* Fix test and add is_valid_simgroup_reduce_type trait
* Online softmax. Improved threadgroup reduce. Tidying up a bit.
* Remove redundant threadgroup_barrier from arg reduce
* Mostly tidying up. Some improvements
* Simplify indexed struct
* tidying
* Reuse operation operator instead of passing it in as a parameter
* Fix how operators are applied to indexed<vec<T,N>>
* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.
* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.
* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math
* Use constant for input instead of const device. Fix strided reduce.
* Use contiguous reduce in tests
* Rename finalize -> to_scalar
* Support integer types max/min (switch with trait-inferred impl later)
* Was worried I was skipping work -> shuffling the 1D test cases
* Add build.rs to avoid metal kernel jit compile overhead
* Improve build. Extract utils
* Compile metal kernels for both macos and ios
* Fixed over xmas and then forgot about it
* Add calculate_reduce_threads util
* Remove old reduce.metal
* Improve f16/bf16 softmax precision by accumulating in f32
* Remove build.rs (for now)
* Move softmax bench to candle-nn
* Remove redundant thread calc util fn
* Use uint over ushort for indices etc
* Use fast exp in MDReduceOp
* Remove nested metal define for softmax
* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
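The headline change here is the online softmax noted in the log above: instead of separate passes for the row maximum and the exponential sum, both statistics are tracked in a single sweep, rescaling the partial sum whenever a new maximum appears. A scalar Rust sketch of the idea, ours for illustration only (the actual kernel does this per threadgroup with simdgroup reductions, vectorized loads, and f32 accumulation for f16/bf16):

    // Single-pass (online) softmax over one row: keep a running maximum `m`
    // and a running sum `s` of exp(x - m); rescale `s` whenever `m` grows.
    fn online_softmax(row: &[f32]) -> Vec<f32> {
        let mut m = f32::NEG_INFINITY;
        let mut s = 0.0f32;
        for &x in row {
            let new_m = m.max(x);
            // exp(m - new_m) shrinks the old partial sum to the new reference max.
            s = s * (m - new_m).exp() + (x - new_m).exp();
            m = new_m;
        }
        row.iter().map(|&x| (x - m).exp() / s).collect()
    }

    fn main() {
        let p = online_softmax(&[1.0, 2.0, 3.0]);
        assert!((p.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    }

Fusing the max and sum passes means the input is read one fewer time per row, which lines up with the 33%-39% throughput gain claimed above.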
@@ -1,10 +1,12 @@
 mod benchmarks;
 
 use criterion::criterion_main;
 
 criterion_main!(
     benchmarks::affine::benches,
     benchmarks::matmul::benches,
     benchmarks::random::benches,
+    benchmarks::reduce::benches,
     benchmarks::where_cond::benches,
     benchmarks::conv_transpose2d::benches,
     benchmarks::qmatmul::benches,
@@ -3,6 +3,7 @@ pub(crate) mod conv_transpose2d;
 pub(crate) mod matmul;
 pub(crate) mod qmatmul;
 pub(crate) mod random;
+pub(crate) mod reduce;
 pub(crate) mod unary;
 pub(crate) mod where_cond;
 
candle-core/benches/benchmarks/reduce.rs (new file, 158 lines)
@@ -0,0 +1,158 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use half::{bf16, f16};
+use std::time::Instant;
+
+fn run_sum(a: &Tensor) {
+    a.sum_keepdim(2).unwrap();
+}
+fn run_arg_min(a: &Tensor) {
+    a.argmin_keepdim(2).unwrap();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    let (lo, up) = (-1000.0f32, 1000.0f32);
+    for device in handler.devices {
+        run_reduce(c, &device, (lo, up), false);
+        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
+        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
+
+        run_arg_reduce(c, &device, (lo, up), false);
+        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
+        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
+
+        run_reduce(c, &device, (lo, up), true);
+        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
+        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
+
+        run_arg_reduce(c, &device, (lo, up), true);
+        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
+        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
+    }
+}
+
+fn run_reduce<T: candle_core::FloatDType>(
+    c: &mut Criterion,
+    device: &Device,
+    (lo, up): (T, T),
+    strided: bool,
+) {
+    let b = 1;
+    let m = 1024;
+    let k = 1024;
+
+    let a = if strided {
+        Tensor::rand(lo, up, (b, m, k), &device)
+            .unwrap()
+            .transpose(0, 2)
+            .unwrap()
+    } else {
+        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
+    };
+
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
+
+    let name = match T::DTYPE {
+        DType::F32 => {
+            if strided {
+                "reduce_f32_strided"
+            } else {
+                "reduce_f32"
+            }
+        }
+        DType::F16 => {
+            if strided {
+                "reduce_f16_strided"
+            } else {
+                "reduce_f16"
+            }
+        }
+        DType::BF16 => {
+            if strided {
+                "reduce_bf16_strided"
+            } else {
+                "reduce_bf16"
+            }
+        }
+        _ => "unknown",
+    };
+
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run_sum(black_box(&a));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn run_arg_reduce<T: candle_core::FloatDType>(
+    c: &mut Criterion,
+    device: &Device,
+    (lo, up): (T, T),
+    strided: bool,
+) {
+    let b = 1;
+    let m = 1024;
+    let k = 1024;
+
+    let a = if strided {
+        Tensor::rand(lo, up, (b, m, k), &device)
+            .unwrap()
+            .transpose(0, 2)
+            .unwrap()
+    } else {
+        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
+    };
+
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
+
+    let name = match T::DTYPE {
+        DType::F32 => {
+            if strided {
+                "arg_reduce_f32_strided"
+            } else {
+                "arg_reduce_f32"
+            }
+        }
+        DType::F16 => {
+            if strided {
+                "arg_reduce_f16_strided"
+            } else {
+                "arg_reduce_f16"
+            }
+        }
+        DType::BF16 => {
+            if strided {
+                "arg_reduce_bf16_strided"
+            } else {
+                "arg_reduce_bf16"
+            }
+        }
+        _ => "unknown",
+    };
+
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run_arg_min(black_box(&a));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
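To try these new benches locally, criterion's usual substring filtering should apply, e.g. `cargo bench --bench bench_main -- reduce_f32` (the `bench_main` target name is assumed from the bench entry point shown above).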
@@ -2,7 +2,6 @@ use crate::{DType, Result};
 use candle_metal_kernels::Kernels;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
-use std::ffi::c_void;
 use std::path::Path;
 use std::sync::{Arc, Mutex, RwLock};
 
@@ -236,7 +235,7 @@ impl MetalDevice {
     pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
         let size = core::mem::size_of_val(data) as NSUInteger;
         let new_buffer = self.device.new_buffer_with_data(
-            data.as_ptr() as *const c_void,
+            data.as_ptr().cast(),
             size,
             MTLResourceOptions::StorageModeManaged,
         );
@@ -265,6 +265,7 @@ impl BackendStorage for MetalStorage {
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         let device = self.device.clone();
 
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
         // Source dims and strides with the sum dims at the end.
@@ -278,13 +279,72 @@ impl BackendStorage for MetalStorage {
                 stride.push(src_stride[dim_idx]);
             }
         }
 
         for &dim_idx in sum_dims.iter() {
             dims.push(src_dims[dim_idx]);
             stride.push(src_stride[dim_idx]);
         }
 
-        // The reduction loop requires the shared array to be properly initialized and for
-        // this we want the number of threads to be a power of two.
+        let reduction_shape = Shape::from(dims.clone());
+
+        if layout.is_contiguous() && reduction_shape.is_contiguous(&stride) {
+            let (name, check_empty, return_index) = match (op, self.dtype) {
+                (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
+                (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
+                (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
+                (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
+                (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
+                (ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
+                (ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
+                (ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
+                (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
+                (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
+                (ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
+                (ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
+                (ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
+                (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
+                (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
+                (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
+                (ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
+                (ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
+                (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
+                (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
+                (ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
+                (ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
+                (ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
+                (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
+                (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
+                (ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
+                (ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
+                (ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
+                (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
+                (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
+                (k, dtype) => {
+                    crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
+                }
+            };
+            if check_empty && layout.shape().elem_count() == 0 {
+                Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+            }
+            let dtype = if return_index { DType::U32 } else { self.dtype };
+            let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
+            let command_buffer = self.device.command_buffer()?;
+            let src = buffer_o(&self.buffer, layout, self.dtype);
+            candle_metal_kernels::call_reduce_contiguous(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                name,
+                src_dims,
+                dst_el,
+                src,
+                &buffer,
+            )
+            .map_err(MetalError::from)?;
+
+            return Ok(Self::new(buffer, device, dst_el, dtype));
+        }
+
         let (name, check_empty, return_index) = match (op, self.dtype) {
             (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
             (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
@@ -316,7 +376,7 @@ impl BackendStorage for MetalStorage {
             (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
             (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
             (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
-            (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
+            (k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
         };
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
@@ -5,14 +5,12 @@ use metal::{
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::sync::RwLock;
 
 pub mod mlx_gemm;
 pub mod sort;
 pub mod utils;
-pub use utils::BufferOffset;
-
 pub use mlx_gemm::{call_mlx_gemm, GemmDType};
 pub use sort::{call_arg_sort, call_mlx_arg_sort};
+pub use utils::BufferOffset;
 use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
 
 const AFFINE: &str = include_str!("affine.metal");
@@ -176,7 +174,7 @@ pub enum MetalKernelError {
     LockError(String),
     #[error("Error while loading library: {0}")]
     LoadLibraryError(String),
-    #[error("Error while loading function: {0:?}")]
+    #[error("Error while loading function: {0}")]
     LoadFunctionError(String),
     #[error("Failed to create compute function")]
     FailedToCreateComputeFunction,
@@ -635,19 +633,31 @@ pub fn call_reduce_contiguous(
     ep: impl EncoderProvider,
     kernels: &Kernels,
    kernel_name: &'static str,
-    length: usize,
+    shape: &[usize],
     out_length: usize,
     input: BufferOffset,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
+    let length = shape.iter().product::<usize>();
+    let num_dims = shape.len();
+    let work_per_threadgroup = length / out_length;
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
-    let elements_to_sum = length / out_length;
 
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, elements_to_sum, &input, output));
+    set_params!(
+        encoder,
+        (
+            length,
+            num_dims,
+            shape,
+            work_per_threadgroup,
+            &input,
+            output
+        )
+    );
 
     let thread_group_count = MTLSize {
         width: out_length as u64,
@@ -657,9 +667,8 @@ pub fn call_reduce_contiguous(
 
     let width = std::cmp::min(
         pipeline.max_total_threads_per_threadgroup(),
-        (elements_to_sum as u64).div_ceil(2),
-    )
-    .next_power_of_two();
+        (work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
+    );
 
     let thread_group_size = MTLSize {
         width,
@@ -686,8 +695,9 @@ pub fn call_reduce_strided(
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
     let length: usize = shape.iter().product();
+    let num_dims = shape.len();
+    let work_per_threadgroup = length / out_length;
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
-    let elements_to_sum = length / out_length;
 
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
@@ -695,7 +705,15 @@ pub fn call_reduce_strided(
 
     set_params!(
         encoder,
-        (shape.len(), shape, strides, elements_to_sum, &input, output)
+        (
+            length,
+            num_dims,
+            shape,
+            strides,
+            work_per_threadgroup,
+            &input,
+            output
+        )
     );
 
     let thread_group_count = MTLSize {
@@ -706,16 +724,14 @@ pub fn call_reduce_strided(
 
     let width = std::cmp::min(
         pipeline.max_total_threads_per_threadgroup(),
-        elements_to_sum as u64,
-    )
-    .next_power_of_two();
+        (work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
+    );
 
     let thread_group_size = MTLSize {
         width,
         height: 1,
         depth: 1,
     };
 
     encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
@@ -729,11 +745,13 @@ pub fn call_last_softmax(
     kernels: &Kernels,
     kernel_name: &'static str,
     length: usize,
-    elements_to_sum: usize,
+    elements: usize,
     input: &Buffer,
     input_offset: usize,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
+    let work_per_threadgroup = elements;
+
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = ep.encoder();
     let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
@@ -741,29 +759,27 @@ pub fn call_last_softmax(
 
     set_params!(
         encoder,
-        (length, elements_to_sum, (input, input_offset), output)
+        (length, work_per_threadgroup, (input, input_offset), output)
     );
 
-    let out_length = length / elements_to_sum;
+    let out_length = length / work_per_threadgroup;
 
     let thread_group_count = MTLSize {
-        width: out_length as u64,
+        width: out_length as NSUInteger,
         height: 1,
         depth: 1,
     };
 
     let width = std::cmp::min(
         pipeline.max_total_threads_per_threadgroup(),
-        elements_to_sum as u64,
-    )
-    .next_power_of_two();
+        (work_per_threadgroup / 2).next_power_of_two() as NSUInteger,
+    );
 
     let thread_group_size = MTLSize {
         width,
         height: 1,
         depth: 1,
     };
 
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
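All three launchers above now size the threadgroup the same way: cap at the pipeline's thread limit, and take (work_per_threadgroup / 2) rounded up to a power of two, since each thread consumes at least two elements before the power-of-two tree reduce in shared memory. A small host-side sketch of that sizing rule (the helper name is ours, not the crate's):

    // Hypothetical helper mirroring the width calculation in the hunks above:
    // each thread reduces at least two elements, and the block-level tree
    // reduce wants a power-of-two thread count.
    fn threadgroup_width(max_threads: u64, work_per_threadgroup: u64) -> u64 {
        std::cmp::min(max_threads, (work_per_threadgroup / 2).next_power_of_two())
    }

    fn main() {
        // Reducing 1024 elements per threadgroup: 512 threads, two loads each.
        assert_eq!(threadgroup_width(1024, 1024), 512);
        // Tiny rows stay valid, since 0u64.next_power_of_two() == 1.
        assert_eq!(threadgroup_width(1024, 1), 1);
    }

Compared with the old `elements_to_sum`-based forms, halving before rounding avoids bumping a row length just above a power of two up to double the threads, most of which would be idle.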
(File diff suppressed because it is too large.)
@@ -1,6 +1,8 @@
 use super::*;
 use half::{bf16, f16};
-use metal::MTLResourceOptions;
+use metal::{Buffer, Device, MTLResourceOptions};
+use rand::prelude::SliceRandom;
+use rand::thread_rng;
 use rand::Rng;
 
 fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
@@ -860,7 +862,12 @@ fn cos_f16() {
     assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
 }
 
-fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
+fn run_reduce<T, U: Clone>(
+    v: &[T],
+    in_length: usize,
+    out_length: usize,
+    name: &'static str,
+) -> Vec<U> {
     let device = device();
     let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
@@ -868,21 +875,24 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T>
     let input = new_buffer(&device, v);
 
     let options = MTLResourceOptions::StorageModeManaged;
-    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
-    let dims = vec![v.len()];
-    let strides = vec![1];
-    call_reduce_strided(
+    let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
+    let shape = vec![in_length];
+    match call_reduce_contiguous(
         &device,
         command_buffer,
         &kernels,
         name,
-        &dims,
-        &strides,
+        &shape,
         out_length,
         BufferOffset::zero_offset(&input),
         &output,
-    )
-    .unwrap();
+    ) {
+        Ok(_) => {}
+        Err(e) => {
+            println!("{e}");
+            panic!();
+        }
+    }
     command_buffer.commit();
     command_buffer.wait_until_completed();
@@ -914,22 +924,187 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str)
     read_to_vec(&output, v.len())
 }
 
-#[test]
-fn reduce_sum() {
-    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
-    let out_length = 1;
-
-    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
-    assert_eq!(approx(results, 4), vec![21.0]);
-}
-
-#[test]
-fn reduce_sum2() {
-    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
-    let out_length = 2;
-
-    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
-    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
-}
+const fn create_array<const N: usize>() -> [f32; N] {
+    let mut array: [f32; N] = [0.0; N];
+    let mut i = 1;
+    while i <= N {
+        array[i - 1] = i as f32;
+        i += 1;
+    }
+    array
+}
+
+const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
+    let mut sum = 0;
+    let mut results: [f32; D] = [0.0; D];
+    let mut i = 1;
+    let mut j = 1;
+    while i <= N {
+        sum += i;
+        i += 1;
+        if i > j * N / D {
+            results[j - 1] = sum as f32;
+            j += 1;
+            sum = 0;
+        }
+    }
+    results
+}
+
+const fn correct_max<const N: usize, const D: usize>() -> [f32; D] {
+    let mut results: [f32; D] = [0.0; D];
+    let mut i = 1;
+    let mut j = 1;
+    while i <= N {
+        i += 1;
+        if i > j * (N / D) {
+            results[j - 1] = (i - 1) as f32;
+            j += 1;
+        }
+    }
+    results
+}
+
+fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
+    let mut max = 0.0;
+    let mut max_index: u32 = 0;
+    let mut results: [u32; D] = [0; D];
+    let mut i = 0;
+    let mut j = 1;
+    while i <= N {
+        if i >= (j * N / D) {
+            results[j - 1] = max_index;
+            max = 0.0;
+            max_index = 0;
+            j += 1;
+        }
+        if i == N {
+            break;
+        }
+        if arr[i] > max {
+            max = arr[i];
+            max_index = i as u32;
+        }
+        i += 1;
+    }
+    results
+}
+
+fn reduce_sum_case<const N: usize, const D: usize>() {
+    let mut v = create_array::<N>();
+    if D == 1 {
+        // Hardens 1-dimensional test cases
+        v.shuffle(&mut thread_rng());
+    }
+    let results = run_reduce(&v, N, D, "fast_sum_f32");
+    assert_eq!(approx(results, 4), correct_sum::<N, D>());
+}
+
+fn reduce_max_case<const N: usize, const D: usize>() {
+    let mut v = create_array::<N>();
+    if D == 1 {
+        // Hardens 1-dimensional test cases
+        v.shuffle(&mut thread_rng());
+    }
+    let results = run_reduce(&v, N, D, "fast_max_f32");
+    assert_eq!(approx(results, 4), correct_max::<N, D>());
+}
+
+fn reduce_argmax_case<const N: usize, const D: usize>() {
+    let mut v = create_array::<N>();
+    if D == 1 {
+        // Hardens 1-dimensional test cases
+        v.shuffle(&mut thread_rng());
+    }
+    let results: Vec<u32> = run_reduce(&v, N, D, "fast_argmax_f32");
+    assert_eq!(results, correct_argmax::<N, D>(v));
+}
+
+#[test]
+fn reduce_sum1() {
+    reduce_sum_case::<9, 1>();
+    reduce_sum_case::<6, 1>();
+    reduce_sum_case::<10, 1>();
+    reduce_sum_case::<64, 1>();
+    reduce_sum_case::<128, 1>();
+    reduce_sum_case::<256, 1>();
+    reduce_sum_case::<512, 1>();
+    reduce_sum_case::<1024, 1>();
+    reduce_sum_case::<2048, 1>();
+    reduce_sum_case::<4096, 1>();
+}
+
+#[test]
+fn reduce_sum2() {
+    reduce_sum_case::<6, 2>();
+    reduce_sum_case::<10, 2>();
+    reduce_sum_case::<64, 2>();
+    reduce_sum_case::<128, 2>();
+    reduce_sum_case::<256, 2>();
+    reduce_sum_case::<512, 2>();
+    reduce_sum_case::<1024, 2>();
+    reduce_sum_case::<2048, 2>();
+    reduce_sum_case::<4096, 2>();
+}
+
+#[test]
+fn reduce_max() {
+    reduce_max_case::<6, 1>();
+    reduce_max_case::<9, 1>();
+    reduce_max_case::<10, 1>();
+    reduce_max_case::<64, 1>();
+    reduce_max_case::<128, 1>();
+    reduce_max_case::<256, 1>();
+    reduce_max_case::<512, 1>();
+    reduce_max_case::<1024, 1>();
+    reduce_max_case::<2048, 1>();
+    reduce_max_case::<4096, 1>();
+
+    reduce_max_case::<6, 2>();
+    reduce_max_case::<10, 2>();
+    reduce_max_case::<64, 2>();
+    reduce_max_case::<128, 2>();
+    reduce_max_case::<256, 2>();
+    reduce_max_case::<512, 2>();
+    reduce_max_case::<1024, 2>();
+    reduce_max_case::<2048, 2>();
+    reduce_max_case::<4096, 2>();
+
+    reduce_max_case::<6, 3>();
+    reduce_max_case::<10, 3>();
+    reduce_max_case::<64, 3>();
+    reduce_max_case::<128, 3>();
+    reduce_max_case::<256, 3>();
+    reduce_max_case::<512, 3>();
+    reduce_max_case::<1024, 3>();
+    reduce_max_case::<2048, 3>();
+    reduce_max_case::<4096, 3>();
+}
+
+#[test]
+fn reduce_argmax() {
+    reduce_argmax_case::<6, 1>();
+    reduce_argmax_case::<9, 1>();
+    reduce_argmax_case::<10, 1>();
+    reduce_argmax_case::<64, 1>();
+    reduce_argmax_case::<128, 1>();
+    reduce_argmax_case::<256, 1>();
+    reduce_argmax_case::<512, 1>();
+    reduce_argmax_case::<1024, 1>();
+    reduce_argmax_case::<2048, 1>();
+}
+
+#[test]
+fn reduce_argmax2() {
+    reduce_argmax_case::<6, 2>();
+    reduce_argmax_case::<10, 2>();
+    reduce_argmax_case::<64, 2>();
+    reduce_argmax_case::<128, 2>();
+    reduce_argmax_case::<256, 2>();
+    reduce_argmax_case::<512, 2>();
+    reduce_argmax_case::<1024, 2>();
+    reduce_argmax_case::<2048, 2>();
+    reduce_argmax_case::<4096, 2>();
+}
 
 #[test]
@@ -983,7 +1158,7 @@ fn softmax() {
     let results = run_softmax(&v, last_dim, "softmax_f16");
     assert_eq!(
         approx_f16(results, 4),
-        vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
+        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2332, 0.6338]
     );
 
     let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
candle-metal-kernels/src/utils.metal (new file, 47 lines)
@@ -0,0 +1,47 @@
+#pragma once
+#include <metal_stdlib>
+using namespace metal;
+
+METAL_FUNC uint nonzero(uint n) {
+    return n == 0 ? 1 : n;
+}
+
+template<uint N>
+constexpr uint nonzero() {
+    return N == 0 ? 1 : N;
+}
+
+template<typename T>
+constexpr ushort granularity() {
+    return nonzero<vec_elements<T>::value>();
+}
+
+METAL_FUNC uint next_p2(uint x) {
+    return 1 << (32 - clz(x - 1));
+}
+
+METAL_FUNC uint prev_p2(uint x) {
+    return 1 << (31 - clz(x));
+}
+
+constant uint MAX_SHARED_MEM = 32767;
+
+template<typename T>
+METAL_FUNC uint max_shared_mem(uint n) {
+    return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
+}
+
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant const uint &num_dims,
+    constant const size_t *dims,
+    constant const size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
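For reference when reading the strided kernels, `get_strided_index` converts a linear element index into a strided buffer offset by peeling off the innermost (last) dimension first. A rough host-side Rust mirror, ours for illustration only:

    // Map a linear index over `dims` to an offset through `strides`,
    // consuming the innermost (last) dimension first, like the Metal
    // get_strided_index above.
    fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
        let mut strided_i = 0;
        for d in (0..dims.len()).rev() {
            strided_i += (idx % dims[d]) * strides[d];
            idx /= dims[d];
        }
        strided_i
    }

    fn main() {
        // A 2x3 tensor viewed transposed as 3x2: dims [3, 2], strides [1, 3].
        // Linear index 4 is element (2, 0) of the view, stored at offset 2.
        assert_eq!(get_strided_index(4, &[3, 2], &[1, 3]), 2);
    }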
@@ -1,4 +1,8 @@
 mod benchmarks;
 
 use criterion::criterion_main;
-criterion_main!(benchmarks::layer_norm::benches, benchmarks::conv::benches);
+criterion_main!(
+    benchmarks::softmax::benches,
+    benchmarks::layer_norm::benches,
+    benchmarks::conv::benches
+);
@@ -1,5 +1,6 @@
 pub(crate) mod conv;
 pub(crate) mod layer_norm;
+pub(crate) mod softmax;
 
 use candle::{Device, Result};
candle-nn/benches/benchmarks/softmax.rs (new file, 49 lines)
@@ -0,0 +1,49 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle::{DType, Device, Tensor};
+use candle_nn::ops::softmax_last_dim;
+use criterion::Throughput;
+use criterion::{black_box, criterion_group, Criterion};
+use std::time::Instant;
+
+fn run(input: &Tensor) {
+    let _ = softmax_last_dim(&input).unwrap();
+}
+
+const B: usize = 1;
+const M: usize = 1024;
+const K: usize = 1024;
+
+fn run_softmax_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+    let elements = B * M * K;
+
+    let input = Tensor::rand(-1000.0f32, 1000.0f32, (B, M, K), &device)
+        .unwrap()
+        .to_dtype(dtype)
+        .unwrap();
+
+    let flops = elements * dtype.size_in_bytes();
+    let mut group = c.benchmark_group(device.bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&input));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let device = BenchDeviceHandler::new().unwrap();
+    for d in device.devices {
+        run_softmax_benchmark(c, &d, DType::F32, "softmax_f32");
+        run_softmax_benchmark(c, &d, DType::BF16, "softmax_bf16");
+        run_softmax_benchmark(c, &d, DType::F16, "softmax_f16");
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);