From 0067fe00a8477b8c817dcf54d4d4084b07b7fc5b Mon Sep 17 00:00:00 2001 From: Thomas Santerre Date: Sat, 20 Apr 2024 18:10:33 -0400 Subject: [PATCH] Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056) * add basic unary bench for sqrt * process unary commands in tiles of 4 * re-enable all benchmarks * rename helper to unary * modify approach to split up tiled and non-tiled operations * undo bench ignore for other tests * update tile size to 2 * only perform the optimization on the contiguous even numbered element case --- candle-core/benches/bench_main.rs | 1 + candle-core/benches/benchmarks/mod.rs | 1 + candle-core/benches/benchmarks/unary.rs | 49 +++ candle-core/src/metal_backend/mod.rs | 379 +++++++++++++++--------- candle-metal-kernels/src/lib.rs | 117 +++++--- candle-metal-kernels/src/unary.metal | 17 +- 6 files changed, 380 insertions(+), 184 deletions(-) create mode 100644 candle-core/benches/benchmarks/unary.rs diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index b0c0b05c..2e1816fd 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -8,4 +8,5 @@ criterion_main!( benchmarks::where_cond::benches, benchmarks::conv_transpose2d::benches, benchmarks::qmatmul::benches, + benchmarks::unary::benches ); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 778034ea..579c5f3f 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod conv_transpose2d; pub(crate) mod matmul; pub(crate) mod qmatmul; pub(crate) mod random; +pub(crate) mod unary; pub(crate) mod where_cond; use candle_core::{Device, Result}; diff --git a/candle-core/benches/benchmarks/unary.rs b/candle-core/benches/benchmarks/unary.rs new file mode 100644 index 00000000..a8e0d025 --- /dev/null +++ b/candle-core/benches/benchmarks/unary.rs @@ -0,0 +1,49 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor) { + a.sqrt().unwrap(); +} + +fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let b = 1; + let m = 1024; + let k = 1024; + + let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device) + .unwrap() + .to_dtype(dtype) + .unwrap() + .reshape((b, m, k)) + .unwrap(); + + let flops = b * m * k * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + for dtype in [DType::F32, DType::BF16, DType::F16] { + let name = format!("sqrt_{:?}", dtype); + run_unary_benchmark(c, &device, dtype, &name); + } + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index daa68c39..12dba381 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -444,155 +444,240 @@ impl BackendStorage for MetalStorage { let command_buffer = device.command_buffer()?; command_buffer.set_label(B::KERNEL); let src = buffer_o(&self.buffer, layout, self.dtype); + + match (el_count % 2, dtype, layout.is_contiguous()) { + (0, DType::BF16 | DType::F16, true) => { + use candle_metal_kernels::unary::contiguous_tiled; + let kernel_name = match (B::KERNEL, dtype) { + ("uabs", DType::F16) => contiguous_tiled::abs::HALF, + ("uabs", DType::F32) => contiguous_tiled::abs::FLOAT, + ("uabs", DType::BF16) => contiguous_tiled::abs::BFLOAT, + ("uceil", DType::F16) => contiguous_tiled::ceil::HALF, + ("uceil", DType::F32) => contiguous_tiled::ceil::FLOAT, + ("uceil", DType::BF16) => contiguous_tiled::ceil::BFLOAT, + ("ucos", DType::F16) => contiguous_tiled::cos::HALF, + ("ucos", DType::F32) => contiguous_tiled::cos::FLOAT, + ("ucos", DType::BF16) => contiguous_tiled::cos::BFLOAT, + ("uerf", DType::F16) => contiguous_tiled::erf::HALF, + ("uerf", DType::F32) => contiguous_tiled::erf::FLOAT, + ("uerf", DType::BF16) => contiguous_tiled::erf::BFLOAT, + ("uexp", DType::F16) => contiguous_tiled::exp::HALF, + ("uexp", DType::F32) => contiguous_tiled::exp::FLOAT, + ("uexp", DType::BF16) => contiguous_tiled::exp::BFLOAT, + ("ufloor", DType::F16) => contiguous_tiled::floor::HALF, + ("ufloor", DType::F32) => contiguous_tiled::floor::FLOAT, + ("ufloor", DType::BF16) => contiguous_tiled::floor::BFLOAT, + ("ugelu_erf", DType::F16) => contiguous_tiled::gelu_erf::HALF, + ("ugelu_erf", DType::F32) => contiguous_tiled::gelu_erf::FLOAT, + ("ugelu_erf", DType::BF16) => contiguous_tiled::gelu_erf::BFLOAT, + ("ugelu", DType::F16) => contiguous_tiled::gelu::HALF, + ("ugelu", DType::F32) => contiguous_tiled::gelu::FLOAT, + ("ugelu", DType::BF16) => contiguous_tiled::gelu::BFLOAT, + ("ulog", DType::F16) => contiguous_tiled::log::HALF, + ("ulog", DType::F32) => contiguous_tiled::log::FLOAT, + ("ulog", DType::BF16) => contiguous_tiled::log::BFLOAT, + ("uneg", DType::F16) => contiguous_tiled::neg::HALF, + ("uneg", DType::F32) => contiguous_tiled::neg::FLOAT, + ("uneg", DType::BF16) => contiguous_tiled::neg::BFLOAT, + ("urecip", DType::F16) => contiguous_tiled::recip::HALF, + ("urecip", DType::F32) => contiguous_tiled::recip::FLOAT, + ("urecip", DType::BF16) => contiguous_tiled::recip::BFLOAT, + ("urelu", DType::F16) => contiguous_tiled::relu::HALF, + ("urelu", DType::F32) => contiguous_tiled::relu::FLOAT, + ("urelu", DType::BF16) => contiguous_tiled::relu::BFLOAT, + ("uround", DType::F16) => contiguous_tiled::round::HALF, + ("uround", DType::F32) => contiguous_tiled::round::FLOAT, + ("uround", DType::BF16) => contiguous_tiled::round::BFLOAT, + ("usilu", DType::F16) => contiguous_tiled::silu::HALF, + ("usilu", DType::F32) => contiguous_tiled::silu::FLOAT, + ("usilu", DType::BF16) => contiguous_tiled::silu::BFLOAT, + ("usin", DType::F16) => contiguous_tiled::sin::HALF, + ("usin", DType::F32) => contiguous_tiled::sin::FLOAT, + ("usin", DType::BF16) => contiguous_tiled::sin::BFLOAT, + ("usqr", DType::F16) => contiguous_tiled::sqr::HALF, + ("usqr", DType::F32) => contiguous_tiled::sqr::FLOAT, + ("usqr", DType::BF16) => contiguous_tiled::sqr::BFLOAT, + ("usqrt", DType::F16) => contiguous_tiled::sqrt::HALF, + ("usqrt", DType::F32) => contiguous_tiled::sqrt::FLOAT, + ("usqrt", DType::BF16) => contiguous_tiled::sqrt::BFLOAT, + ("utanh", DType::F16) => contiguous_tiled::tanh::HALF, + ("utanh", DType::F32) => contiguous_tiled::tanh::FLOAT, + ("utanh", DType::BF16) => contiguous_tiled::tanh::BFLOAT, + ("usign", DType::F16) => contiguous_tiled::sign::HALF, + ("usign", DType::F32) => contiguous_tiled::sign::FLOAT, + ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT, + ("usign", DType::I64) => contiguous_tiled::sign::I64, + (name, dtype) => { + crate::bail!( + "Metal contiguous_tiled unary {name} {dtype:?} not implemented" + ) + } + }; + candle_metal_kernels::call_unary_contiguous_tiled( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } + (_, _, true) => { + use candle_metal_kernels::unary::contiguous; + let kernel_name = match (B::KERNEL, dtype) { + ("uabs", DType::F16) => contiguous::abs::HALF, + ("uabs", DType::F32) => contiguous::abs::FLOAT, + ("uabs", DType::BF16) => contiguous::abs::BFLOAT, + ("uceil", DType::F16) => contiguous::ceil::HALF, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("uceil", DType::BF16) => contiguous::ceil::BFLOAT, + ("ucos", DType::F16) => contiguous::cos::HALF, + ("ucos", DType::F32) => contiguous::cos::FLOAT, + ("ucos", DType::BF16) => contiguous::cos::BFLOAT, + ("uerf", DType::F16) => contiguous::erf::HALF, + ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uerf", DType::BF16) => contiguous::erf::BFLOAT, + ("uexp", DType::F16) => contiguous::exp::HALF, + ("uexp", DType::F32) => contiguous::exp::FLOAT, + ("uexp", DType::BF16) => contiguous::exp::BFLOAT, + ("ufloor", DType::F16) => contiguous::floor::HALF, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("ufloor", DType::BF16) => contiguous::floor::BFLOAT, + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT, + ("ugelu", DType::F16) => contiguous::gelu::HALF, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT, + ("ulog", DType::F16) => contiguous::log::HALF, + ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ulog", DType::BF16) => contiguous::log::BFLOAT, + ("uneg", DType::F16) => contiguous::neg::HALF, + ("uneg", DType::F32) => contiguous::neg::FLOAT, + ("uneg", DType::BF16) => contiguous::neg::BFLOAT, + ("urecip", DType::F16) => contiguous::recip::HALF, + ("urecip", DType::F32) => contiguous::recip::FLOAT, + ("urecip", DType::BF16) => contiguous::recip::BFLOAT, + ("urelu", DType::F16) => contiguous::relu::HALF, + ("urelu", DType::F32) => contiguous::relu::FLOAT, + ("urelu", DType::BF16) => contiguous::relu::BFLOAT, + ("uround", DType::F16) => contiguous::round::HALF, + ("uround", DType::F32) => contiguous::round::FLOAT, + ("uround", DType::BF16) => contiguous::round::BFLOAT, + ("usilu", DType::F16) => contiguous::silu::HALF, + ("usilu", DType::F32) => contiguous::silu::FLOAT, + ("usilu", DType::BF16) => contiguous::silu::BFLOAT, + ("usin", DType::F16) => contiguous::sin::HALF, + ("usin", DType::F32) => contiguous::sin::FLOAT, + ("usin", DType::BF16) => contiguous::sin::BFLOAT, + ("usqr", DType::F16) => contiguous::sqr::HALF, + ("usqr", DType::F32) => contiguous::sqr::FLOAT, + ("usqr", DType::BF16) => contiguous::sqr::BFLOAT, + ("usqrt", DType::F16) => contiguous::sqrt::HALF, + ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, + ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT, + ("utanh", DType::F16) => contiguous::tanh::HALF, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, + ("utanh", DType::BF16) => contiguous::tanh::BFLOAT, + ("usign", DType::F16) => contiguous::sign::HALF, + ("usign", DType::F32) => contiguous::sign::FLOAT, + ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I64) => contiguous::sign::I64, + (name, dtype) => { + crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") + } + }; + candle_metal_kernels::call_unary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } + (_, _, false) => { + use candle_metal_kernels::unary::strided; + let kernel_name = match (B::KERNEL, dtype) { + ("ucos", DType::F32) => strided::cos::FLOAT, + ("usin", DType::F32) => strided::sin::FLOAT, + ("usqr", DType::F32) => strided::sqr::FLOAT, + ("usqrt", DType::F32) => strided::sqrt::FLOAT, + ("uneg", DType::F32) => strided::neg::FLOAT, + ("uexp", DType::F32) => strided::exp::FLOAT, + ("ulog", DType::F32) => strided::log::FLOAT, + ("ugelu", DType::F32) => strided::gelu::FLOAT, + ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, + ("uerf", DType::F32) => strided::erf::FLOAT, + ("usilu", DType::F32) => strided::silu::FLOAT, + ("uabs", DType::F32) => strided::abs::FLOAT, + ("uceil", DType::F32) => strided::ceil::FLOAT, + ("ufloor", DType::F32) => strided::floor::FLOAT, + ("urelu", DType::F32) => strided::relu::FLOAT, + ("uround", DType::F32) => strided::round::FLOAT, + ("utanh", DType::F32) => strided::tanh::FLOAT, + + ("ucos", DType::F16) => strided::cos::HALF, + ("usin", DType::F16) => strided::sin::HALF, + ("usqr", DType::F16) => strided::sqr::HALF, + ("usqrt", DType::F16) => strided::sqrt::HALF, + ("uneg", DType::F16) => strided::neg::HALF, + ("uexp", DType::F16) => strided::exp::HALF, + ("ulog", DType::F16) => strided::log::HALF, + ("ugelu", DType::F16) => strided::gelu::HALF, + ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, + ("uerf", DType::F16) => strided::erf::HALF, + ("usilu", DType::F16) => strided::silu::HALF, + ("uabs", DType::F16) => strided::abs::HALF, + ("uceil", DType::F16) => strided::ceil::HALF, + ("ufloor", DType::F16) => strided::floor::HALF, + ("urelu", DType::F16) => strided::relu::HALF, + ("uround", DType::F16) => strided::round::HALF, + ("utanh", DType::F16) => strided::tanh::HALF, + + ("ucos", DType::BF16) => strided::cos::BFLOAT, + ("usin", DType::BF16) => strided::sin::BFLOAT, + ("usqr", DType::BF16) => strided::sqr::BFLOAT, + ("usqrt", DType::BF16) => strided::sqrt::BFLOAT, + ("uneg", DType::BF16) => strided::neg::BFLOAT, + ("uexp", DType::BF16) => strided::exp::BFLOAT, + ("ulog", DType::BF16) => strided::log::BFLOAT, + ("ugelu", DType::BF16) => strided::gelu::BFLOAT, + ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT, + ("uerf", DType::BF16) => strided::erf::BFLOAT, + ("usilu", DType::BF16) => strided::silu::BFLOAT, + ("uabs", DType::BF16) => strided::abs::BFLOAT, + ("uceil", DType::BF16) => strided::ceil::BFLOAT, + ("ufloor", DType::BF16) => strided::floor::BFLOAT, + ("urelu", DType::BF16) => strided::relu::BFLOAT, + ("uround", DType::BF16) => strided::round::BFLOAT, + ("utanh", DType::BF16) => strided::tanh::BFLOAT, + + (name, dtype) => { + crate::bail!("Metal strided unary {name} {dtype:?} not implemented") + } + }; + let dst = BufferOffset::zero_offset(&buffer); + candle_metal_kernels::call_unary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + layout.dims(), + src, + layout.stride(), + dst, + ) + .map_err(MetalError::from)?; + } + } + if layout.is_contiguous() { - use candle_metal_kernels::unary::contiguous; - - let kernel_name = match (B::KERNEL, dtype) { - ("uabs", DType::F16) => contiguous::abs::HALF, - ("uabs", DType::F32) => contiguous::abs::FLOAT, - ("uabs", DType::BF16) => contiguous::abs::BFLOAT, - ("uceil", DType::F16) => contiguous::ceil::HALF, - ("uceil", DType::F32) => contiguous::ceil::FLOAT, - ("uceil", DType::BF16) => contiguous::ceil::BFLOAT, - ("ucos", DType::F16) => contiguous::cos::HALF, - ("ucos", DType::F32) => contiguous::cos::FLOAT, - ("ucos", DType::BF16) => contiguous::cos::BFLOAT, - ("uerf", DType::F16) => contiguous::erf::HALF, - ("uerf", DType::F32) => contiguous::erf::FLOAT, - ("uerf", DType::BF16) => contiguous::erf::BFLOAT, - ("uexp", DType::F16) => contiguous::exp::HALF, - ("uexp", DType::F32) => contiguous::exp::FLOAT, - ("uexp", DType::BF16) => contiguous::exp::BFLOAT, - ("ufloor", DType::F16) => contiguous::floor::HALF, - ("ufloor", DType::F32) => contiguous::floor::FLOAT, - ("ufloor", DType::BF16) => contiguous::floor::BFLOAT, - ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, - ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, - ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT, - ("ugelu", DType::F16) => contiguous::gelu::HALF, - ("ugelu", DType::F32) => contiguous::gelu::FLOAT, - ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT, - ("ulog", DType::F16) => contiguous::log::HALF, - ("ulog", DType::F32) => contiguous::log::FLOAT, - ("ulog", DType::BF16) => contiguous::log::BFLOAT, - ("uneg", DType::F16) => contiguous::neg::HALF, - ("uneg", DType::F32) => contiguous::neg::FLOAT, - ("uneg", DType::BF16) => contiguous::neg::BFLOAT, - ("urecip", DType::F16) => contiguous::recip::HALF, - ("urecip", DType::F32) => contiguous::recip::FLOAT, - ("urecip", DType::BF16) => contiguous::recip::BFLOAT, - ("urelu", DType::F16) => contiguous::relu::HALF, - ("urelu", DType::F32) => contiguous::relu::FLOAT, - ("urelu", DType::BF16) => contiguous::relu::BFLOAT, - ("uround", DType::F16) => contiguous::round::HALF, - ("uround", DType::F32) => contiguous::round::FLOAT, - ("uround", DType::BF16) => contiguous::round::BFLOAT, - ("usilu", DType::F16) => contiguous::silu::HALF, - ("usilu", DType::F32) => contiguous::silu::FLOAT, - ("usilu", DType::BF16) => contiguous::silu::BFLOAT, - ("usin", DType::F16) => contiguous::sin::HALF, - ("usin", DType::F32) => contiguous::sin::FLOAT, - ("usin", DType::BF16) => contiguous::sin::BFLOAT, - ("usqr", DType::F16) => contiguous::sqr::HALF, - ("usqr", DType::F32) => contiguous::sqr::FLOAT, - ("usqr", DType::BF16) => contiguous::sqr::BFLOAT, - ("usqrt", DType::F16) => contiguous::sqrt::HALF, - ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, - ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT, - ("utanh", DType::F16) => contiguous::tanh::HALF, - ("utanh", DType::F32) => contiguous::tanh::FLOAT, - ("utanh", DType::BF16) => contiguous::tanh::BFLOAT, - ("usign", DType::F16) => contiguous::sign::HALF, - ("usign", DType::F32) => contiguous::sign::FLOAT, - ("usign", DType::BF16) => contiguous::sign::BFLOAT, - ("usign", DType::I64) => contiguous::sign::I64, - (name, dtype) => { - crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") - } - }; - candle_metal_kernels::call_unary_contiguous( - &device.device, - &command_buffer, - &device.kernels, - kernel_name, - el_count, - src, - &buffer, - ) - .map_err(MetalError::from)?; } else { - use candle_metal_kernels::unary::strided; - let kernel_name = match (B::KERNEL, dtype) { - ("ucos", DType::F32) => strided::cos::FLOAT, - ("usin", DType::F32) => strided::sin::FLOAT, - ("usqr", DType::F32) => strided::sqr::FLOAT, - ("usqrt", DType::F32) => strided::sqrt::FLOAT, - ("uneg", DType::F32) => strided::neg::FLOAT, - ("uexp", DType::F32) => strided::exp::FLOAT, - ("ulog", DType::F32) => strided::log::FLOAT, - ("ugelu", DType::F32) => strided::gelu::FLOAT, - ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, - ("uerf", DType::F32) => strided::erf::FLOAT, - ("usilu", DType::F32) => strided::silu::FLOAT, - ("uabs", DType::F32) => strided::abs::FLOAT, - ("uceil", DType::F32) => strided::ceil::FLOAT, - ("ufloor", DType::F32) => strided::floor::FLOAT, - ("urelu", DType::F32) => strided::relu::FLOAT, - ("uround", DType::F32) => strided::round::FLOAT, - ("utanh", DType::F32) => strided::tanh::FLOAT, - - ("ucos", DType::F16) => strided::cos::HALF, - ("usin", DType::F16) => strided::sin::HALF, - ("usqr", DType::F16) => strided::sqr::HALF, - ("usqrt", DType::F16) => strided::sqrt::HALF, - ("uneg", DType::F16) => strided::neg::HALF, - ("uexp", DType::F16) => strided::exp::HALF, - ("ulog", DType::F16) => strided::log::HALF, - ("ugelu", DType::F16) => strided::gelu::HALF, - ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, - ("uerf", DType::F16) => strided::erf::HALF, - ("usilu", DType::F16) => strided::silu::HALF, - ("uabs", DType::F16) => strided::abs::HALF, - ("uceil", DType::F16) => strided::ceil::HALF, - ("ufloor", DType::F16) => strided::floor::HALF, - ("urelu", DType::F16) => strided::relu::HALF, - ("uround", DType::F16) => strided::round::HALF, - ("utanh", DType::F16) => strided::tanh::HALF, - - ("ucos", DType::BF16) => strided::cos::BFLOAT, - ("usin", DType::BF16) => strided::sin::BFLOAT, - ("usqr", DType::BF16) => strided::sqr::BFLOAT, - ("usqrt", DType::BF16) => strided::sqrt::BFLOAT, - ("uneg", DType::BF16) => strided::neg::BFLOAT, - ("uexp", DType::BF16) => strided::exp::BFLOAT, - ("ulog", DType::BF16) => strided::log::BFLOAT, - ("ugelu", DType::BF16) => strided::gelu::BFLOAT, - ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT, - ("uerf", DType::BF16) => strided::erf::BFLOAT, - ("usilu", DType::BF16) => strided::silu::BFLOAT, - ("uabs", DType::BF16) => strided::abs::BFLOAT, - ("uceil", DType::BF16) => strided::ceil::BFLOAT, - ("ufloor", DType::BF16) => strided::floor::BFLOAT, - ("urelu", DType::BF16) => strided::relu::BFLOAT, - ("uround", DType::BF16) => strided::round::BFLOAT, - ("utanh", DType::BF16) => strided::tanh::BFLOAT, - - (name, dtype) => { - crate::bail!("Metal strided unary {name} {dtype:?} not implemented") - } - }; - let dst = BufferOffset::zero_offset(&buffer); - candle_metal_kernels::call_unary_strided( - &device.device, - &command_buffer, - &device.kernels, - kernel_name, - layout.dims(), - src, - layout.stride(), - dst, - ) - .map_err(MetalError::from)?; } Ok(Self::new(buffer, device.clone(), el_count, dtype)) } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e05797a2..10f942b4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -74,6 +74,30 @@ macro_rules! ops{ } } + pub mod contiguous_tiled { + pub struct Kernel(pub &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); + } + )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32_tiled"); + pub const HALF: Kernel = Kernel("copy_f16_tiled"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); + pub const I64: Kernel = Kernel("copy_i64_tiled"); + pub const U32: Kernel = Kernel("copy_u32_tiled"); + pub const U8: Kernel = Kernel("copy_u8_tiled"); + } + } + pub mod strided { pub struct Kernel(pub &'static str); $( @@ -267,30 +291,6 @@ impl Kernels { } } -#[allow(clippy::too_many_arguments)] -pub fn call_unary_contiguous( - device: &Device, - command_buffer: &CommandBufferRef, - kernels: &Kernels, - kernel_name: unary::contiguous::Kernel, - length: usize, - input: BufferOffset, - output: &Buffer, -) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; - let encoder = command_buffer.new_compute_command_encoder(); - encoder.set_compute_pipeline_state(&pipeline); - - set_params!(encoder, (length, &input, output)); - - let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); - encoder.use_resource(output, metal::MTLResourceUsage::Write); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); - Ok(()) -} - #[allow(clippy::too_many_arguments)] pub fn call_copy2d( device: &Device, @@ -334,6 +334,58 @@ pub fn call_copy2d( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_unary_contiguous_tiled( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: unary::contiguous_tiled::Kernel, + length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; + let encoder = command_buffer.new_compute_command_encoder(); + let tile_size = 2; + let tiles = length.div_ceil(tile_size); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_unary_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: unary::contiguous::Kernel, + length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[allow(clippy::too_many_arguments)] pub fn call_unary_strided( device: &Device, @@ -347,16 +399,13 @@ pub fn call_unary_strided( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; + let length: usize = shape.iter().product(); let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.set_compute_pipeline_state(&pipeline); - - let length: usize = shape.iter().product(); set_params!(encoder, (length, num_dims, shape, strides, &input, &output)); - - let width: usize = shape.iter().product(); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); @@ -410,10 +459,10 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); - encoder.set_compute_pipeline_state(&pipeline); - let length: usize = shape.iter().product(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( @@ -427,14 +476,12 @@ pub fn call_binary_strided( output ) ); - - let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); + Ok(()) } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index ec793eae..143e9500 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -68,6 +68,8 @@ template METAL_FUNC T silu(T in){ return in / (static_cast(1) + exp(-in)); } +#define TILE_SIZE 2 + #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ constant size_t &dim, \ @@ -79,8 +81,8 @@ kernel void FN_NAME( \ return; \ } \ output[tid] = TYPENAME(FN(float(input[tid]))); \ -}\ -kernel void FN_NAME_STRIDED( \ +} \ +kernel void FN_NAME##_##strided( \ constant size_t &dim, \ constant size_t &num_dims, \ constant size_t *dims, \ @@ -93,6 +95,17 @@ kernel void FN_NAME_STRIDED( \ return; \ } \ output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \ +} \ +kernel void FN_NAME##_##tiled( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + for (uint i = 0; i < TILE_SIZE; i++) { \ + const uint idx = tid * TILE_SIZE + i; \ + output[idx] = TYPENAME(FN(float(input[idx]))); \ + } \ } #define UNARY_OP(NAME) \