Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)

* add basic unary bench for sqrt

* process unary commands in tiles of 4

* re-enable all benchmarks

* rename helper to unary

* modify approach to split up tiled and non-tiled operations

* undo bench ignore for other tests

* update tile size to 2

* only perform the optimization in the contiguous case with an even number of elements
This commit is contained in:
Thomas Santerre
2024-04-20 18:10:33 -04:00
committed by GitHub
parent 587ee3bb6f
commit 0067fe00a8
6 changed files with 380 additions and 184 deletions

View File

@ -8,4 +8,5 @@ criterion_main!(
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
benchmarks::unary::benches
);

View File

@ -3,6 +3,7 @@ pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
pub(crate) mod unary;
pub(crate) mod where_cond;
use candle_core::{Device, Result};

View File

@ -0,0 +1,49 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
/// Benchmark payload: computes the element-wise square root of `a` and discards it.
///
/// Panics if the `sqrt` call returns an error.
fn run(a: &Tensor) {
    drop(a.sqrt().unwrap());
}
/// Registers a criterion benchmark group that times `sqrt` on a `(1, 1024, 1024)` tensor.
///
/// * `c` - criterion instance the benchmark group is added to.
/// * `device` - device the input tensor is created on and that is synced after each batch.
/// * `dtype` - element type the `f32` range is converted to before benchmarking.
/// * `name` - benchmark name, passed through `device.bench_name`.
///
/// Panics if building the tensor, running the op, or syncing the device fails.
fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    // `device` is already a reference, so it is forwarded as-is.
    let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
        .unwrap()
        .to_dtype(dtype)
        .unwrap()
        .reshape((b, m, k))
        .unwrap();

    // Throughput is reported in bytes (the size of the input tensor), not in FLOPs.
    let bytes = b * m * k * dtype.size_in_bytes();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(bytes as u64));
    group.bench_function("iter", move |bencher| {
        bencher.iter_custom(|iters| {
            let start = Instant::now();
            for _ in 0..iters {
                run(black_box(&tensor));
            }
            // Sync before reading the clock; presumably this waits for the queued
            // device work so the measurement covers execution, not just submission.
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [DType::F32, DType::BF16, DType::F16] {
let name = format!("sqrt_{:?}", dtype);
run_unary_benchmark(c, &device, dtype, &name);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -444,9 +444,88 @@ impl BackendStorage for MetalStorage {
let command_buffer = device.command_buffer()?;
command_buffer.set_label(B::KERNEL);
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
match (el_count % 2, dtype, layout.is_contiguous()) {
(0, DType::BF16 | DType::F16, true) => {
use candle_metal_kernels::unary::contiguous_tiled;
let kernel_name = match (B::KERNEL, dtype) {
("uabs", DType::F16) => contiguous_tiled::abs::HALF,
("uabs", DType::F32) => contiguous_tiled::abs::FLOAT,
("uabs", DType::BF16) => contiguous_tiled::abs::BFLOAT,
("uceil", DType::F16) => contiguous_tiled::ceil::HALF,
("uceil", DType::F32) => contiguous_tiled::ceil::FLOAT,
("uceil", DType::BF16) => contiguous_tiled::ceil::BFLOAT,
("ucos", DType::F16) => contiguous_tiled::cos::HALF,
("ucos", DType::F32) => contiguous_tiled::cos::FLOAT,
("ucos", DType::BF16) => contiguous_tiled::cos::BFLOAT,
("uerf", DType::F16) => contiguous_tiled::erf::HALF,
("uerf", DType::F32) => contiguous_tiled::erf::FLOAT,
("uerf", DType::BF16) => contiguous_tiled::erf::BFLOAT,
("uexp", DType::F16) => contiguous_tiled::exp::HALF,
("uexp", DType::F32) => contiguous_tiled::exp::FLOAT,
("uexp", DType::BF16) => contiguous_tiled::exp::BFLOAT,
("ufloor", DType::F16) => contiguous_tiled::floor::HALF,
("ufloor", DType::F32) => contiguous_tiled::floor::FLOAT,
("ufloor", DType::BF16) => contiguous_tiled::floor::BFLOAT,
("ugelu_erf", DType::F16) => contiguous_tiled::gelu_erf::HALF,
("ugelu_erf", DType::F32) => contiguous_tiled::gelu_erf::FLOAT,
("ugelu_erf", DType::BF16) => contiguous_tiled::gelu_erf::BFLOAT,
("ugelu", DType::F16) => contiguous_tiled::gelu::HALF,
("ugelu", DType::F32) => contiguous_tiled::gelu::FLOAT,
("ugelu", DType::BF16) => contiguous_tiled::gelu::BFLOAT,
("ulog", DType::F16) => contiguous_tiled::log::HALF,
("ulog", DType::F32) => contiguous_tiled::log::FLOAT,
("ulog", DType::BF16) => contiguous_tiled::log::BFLOAT,
("uneg", DType::F16) => contiguous_tiled::neg::HALF,
("uneg", DType::F32) => contiguous_tiled::neg::FLOAT,
("uneg", DType::BF16) => contiguous_tiled::neg::BFLOAT,
("urecip", DType::F16) => contiguous_tiled::recip::HALF,
("urecip", DType::F32) => contiguous_tiled::recip::FLOAT,
("urecip", DType::BF16) => contiguous_tiled::recip::BFLOAT,
("urelu", DType::F16) => contiguous_tiled::relu::HALF,
("urelu", DType::F32) => contiguous_tiled::relu::FLOAT,
("urelu", DType::BF16) => contiguous_tiled::relu::BFLOAT,
("uround", DType::F16) => contiguous_tiled::round::HALF,
("uround", DType::F32) => contiguous_tiled::round::FLOAT,
("uround", DType::BF16) => contiguous_tiled::round::BFLOAT,
("usilu", DType::F16) => contiguous_tiled::silu::HALF,
("usilu", DType::F32) => contiguous_tiled::silu::FLOAT,
("usilu", DType::BF16) => contiguous_tiled::silu::BFLOAT,
("usin", DType::F16) => contiguous_tiled::sin::HALF,
("usin", DType::F32) => contiguous_tiled::sin::FLOAT,
("usin", DType::BF16) => contiguous_tiled::sin::BFLOAT,
("usqr", DType::F16) => contiguous_tiled::sqr::HALF,
("usqr", DType::F32) => contiguous_tiled::sqr::FLOAT,
("usqr", DType::BF16) => contiguous_tiled::sqr::BFLOAT,
("usqrt", DType::F16) => contiguous_tiled::sqrt::HALF,
("usqrt", DType::F32) => contiguous_tiled::sqrt::FLOAT,
("usqrt", DType::BF16) => contiguous_tiled::sqrt::BFLOAT,
("utanh", DType::F16) => contiguous_tiled::tanh::HALF,
("utanh", DType::F32) => contiguous_tiled::tanh::FLOAT,
("utanh", DType::BF16) => contiguous_tiled::tanh::BFLOAT,
("usign", DType::F16) => contiguous_tiled::sign::HALF,
("usign", DType::F32) => contiguous_tiled::sign::FLOAT,
("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT,
("usign", DType::I64) => contiguous_tiled::sign::I64,
(name, dtype) => {
crate::bail!(
"Metal contiguous_tiled unary {name} {dtype:?} not implemented"
)
}
};
candle_metal_kernels::call_unary_contiguous_tiled(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, true) => {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("uabs", DType::F16) => contiguous::abs::HALF,
("uabs", DType::F32) => contiguous::abs::FLOAT,
@ -520,7 +599,8 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
} else {
}
(_, _, false) => {
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
@ -594,6 +674,11 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
}
if layout.is_contiguous() {
} else {
}
Ok(Self::new(buffer, device.clone(), el_count, dtype))
}

View File

@ -74,6 +74,30 @@ macro_rules! ops{
}
}
// Name table for the tiled contiguous unary kernels.
// Every constant wraps a string of the form `<op>_<dtype>_tiled`.
pub mod contiguous_tiled {
// Newtype around the kernel name string.
pub struct Kernel(pub &'static str);
$(
// One sub-module per `$name` supplied to the enclosing macro repetition.
// NOTE(review): a name is generated for every dtype below; whether a matching
// kernel exists for each op/dtype pair cannot be confirmed from here.
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled"));
}
)+
// `copy` is spelled out by hand rather than produced by the `$name` repetition.
pub mod copy {
use super::Kernel;
pub const FLOAT: Kernel = Kernel("copy_f32_tiled");
pub const HALF: Kernel = Kernel("copy_f16_tiled");
pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled");
pub const I64: Kernel = Kernel("copy_i64_tiled");
pub const U32: Kernel = Kernel("copy_u32_tiled");
pub const U8: Kernel = Kernel("copy_u8_tiled");
}
}
pub mod strided {
pub struct Kernel(pub &'static str);
$(
@ -267,30 +291,6 @@ impl Kernels {
}
}
/// Encodes a contiguous unary kernel over `length` elements, reading `input` and
/// writing `output`.
///
/// The pipeline named by `kernel_name` is loaded from the unary source, bound on a
/// new compute encoder of `command_buffer`, and dispatched with the thread-group
/// split that `linear_split` returns for `length`.
///
/// Returns an error if the pipeline cannot be loaded.
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let unary::contiguous::Kernel(name) = kernel_name;
    let pipeline = kernels.load_pipeline(device, Source::Unary, name)?;
    let (group_count, group_size) = linear_split(&pipeline, length);

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, &input, output));
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(group_count, group_size);
    encoder.end_encoding();
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_copy2d(
device: &Device,
@ -334,6 +334,58 @@ pub fn call_copy2d(
Ok(())
}
/// Encodes a tiled contiguous unary kernel, reading `input` and writing `output`.
///
/// The kernel receives the full `length` as a parameter, but the dispatch is sized
/// for `ceil(length / TILE_SIZE)` work items, one per tile.
///
/// Returns an error if the pipeline cannot be loaded.
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: unary::contiguous_tiled::Kernel,
    length: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    // NOTE(review): presumably has to stay in sync with the tile size used by the
    // Metal unary kernels — confirm against the shader source.
    const TILE_SIZE: usize = 2;

    let unary::contiguous_tiled::Kernel(name) = kernel_name;
    let pipeline = kernels.load_pipeline(device, Source::Unary, name)?;
    let tile_count = length.div_ceil(TILE_SIZE);
    let (group_count, group_size) = linear_split(&pipeline, tile_count);

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, &input, output));
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(group_count, group_size);
    encoder.end_encoding();
    Ok(())
}
/// Encodes a contiguous unary kernel over `length` elements, reading `input` and
/// writing `output`.
///
/// The pipeline named by `kernel_name` is loaded from the unary source, bound on a
/// new compute encoder of `command_buffer`, and dispatched with the thread-group
/// split that `linear_split` returns for `length`.
///
/// Returns an error if the pipeline cannot be loaded.
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let unary::contiguous::Kernel(name) = kernel_name;
    let pipeline = kernels.load_pipeline(device, Source::Unary, name)?;
    let (group_count, group_size) = linear_split(&pipeline, length);

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, &input, output));
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(group_count, group_size);
    encoder.end_encoding();
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
@ -347,16 +399,13 @@ pub fn call_unary_strided(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
let width: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
@ -410,10 +459,10 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
@ -427,14 +476,12 @@ pub fn call_binary_strided(
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}

View File

@ -68,6 +68,8 @@ template <typename T> METAL_FUNC T silu(T in){
return in / (static_cast<T>(1) + exp(-in));
}
#define TILE_SIZE 2
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@ -79,8 +81,8 @@ kernel void FN_NAME( \
return; \
} \
output[tid] = TYPENAME(FN(float(input[tid]))); \
}\
kernel void FN_NAME_STRIDED( \
} \
kernel void FN_NAME##_##strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
@ -93,6 +95,17 @@ kernel void FN_NAME_STRIDED( \
return; \
} \
output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
} \
/* Tiled variant: each thread processes TILE_SIZE consecutive elements. */ \
kernel void FN_NAME##_##tiled( \
    constant size_t &dim, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    for (uint i = 0; i < TILE_SIZE; i++) { \
        /* Index in size_t so tid * TILE_SIZE cannot wrap a 32-bit uint. */ \
        const size_t idx = size_t(tid) * TILE_SIZE + i; \
        /* Whole threadgroups are dispatched, so trailing threads (or the last */ \
        /* element of a partial tile) can fall past dim and must not touch memory. */ \
        if (idx >= dim) { \
            return; \
        } \
        output[idx] = TYPENAME(FN(float(input[idx]))); \
    } \
}
#define UNARY_OP(NAME) \