From e8e24f1284decb5b56a1e3f3ff41e49860d01244 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:37:56 +0100 Subject: [PATCH] Follow crate conventions --- candle-core/benches/fill.rs | 16 ++++- candle-core/src/metal_backend.rs | 20 ++++-- candle-metal-kernels/Cargo.toml | 4 -- candle-metal-kernels/src/lib.rs | 110 +++++++++++++----------------- candle-metal-kernels/src/tests.rs | 2 +- 5 files changed, 75 insertions(+), 77 deletions(-) diff --git a/candle-core/benches/fill.rs b/candle-core/benches/fill.rs index 9bcb4775..9bd0aa72 100644 --- a/candle-core/benches/fill.rs +++ b/candle-core/benches/fill.rs @@ -22,7 +22,11 @@ fn criterion_benchmark(c: &mut Criterion) { bencher.iter_custom(|iters| { let start = Instant::now(); for _i in 0..iters { - run(black_box((b, rows, columns)), black_box(DType::U8), black_box(&device1)); + run( + black_box((b, rows, columns)), + black_box(DType::U8), + black_box(&device1), + ); } if let Device::Metal(device) = &device1 { device.wait_until_completed().unwrap(); @@ -35,12 +39,18 @@ fn criterion_benchmark(c: &mut Criterion) { group.finish(); let mut group = c.benchmark_group("fill_metal_f32"); - group.throughput(Throughput::Bytes((flops * DType::F32.size_in_bytes()) as u64)); + group.throughput(Throughput::Bytes( + (flops * DType::F32.size_in_bytes()) as u64, + )); group.bench_function("iter", move |bencher| { bencher.iter_custom(|iters| { let start = Instant::now(); for _i in 0..iters { - run(black_box((b, rows, columns)), black_box(DType::F32), black_box(&device2)); + run( + black_box((b, rows, columns)), + black_box(DType::F32), + black_box(&device2), + ); } if let Device::Metal(device) = &device2 { device.wait_until_completed().unwrap(); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 21eb1336..3f6060ce 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvT use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; -use candle_metal_kernels::{FillOp, Unary, Kernels}; +use candle_metal_kernels::Kernels; use half::{bf16, f16}; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; @@ -1405,15 +1405,14 @@ impl BackendDevice for MetalDevice { let command_buffer = self.command_buffer()?; command_buffer.set_label("zeros"); - // This assumes the zero value of this DType is equal to 0x00u8 + // This kernel assumes the zero value of this DType is equal to 0x00u8 // (which is true for all current types) - Unary::fill( - &self.device, + candle_metal_kernels::call_fill_u8( &command_buffer, &self.kernels, shape.elem_count(), &buffer, - 0u8, + 0, ) .map_err(MetalError::from)?; @@ -1427,7 +1426,7 @@ impl BackendDevice for MetalDevice { macro_rules! fill { ($value:expr) => { - Unary::fill( + candle_metal_kernels::call_fill( &self.device, &command_buffer, &self.kernels, @@ -1439,7 +1438,14 @@ impl BackendDevice for MetalDevice { }; } match dtype { - DType::U8 => fill!(1u8), + DType::U8 => candle_metal_kernels::call_fill_u8( + &command_buffer, + &self.kernels, + shape.elem_count(), + &buffer, + 1u8, + ) + .map_err(MetalError::from)?, DType::U32 => fill!(1u32), DType::I64 => fill!(1i64), DType::BF16 => fill!(bf16::ONE), diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 25446d29..162adbd7 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -20,7 +20,3 @@ num-traits = "0.2.17" [dev-dependencies] rand = "0.8.5" criterion = "0.5.1" - -[[bench]] -name = "fill" -harness = false diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index f5b0653b..5db4c2cb 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -5,7 +5,6 @@ use metal::{ }; use std::collections::HashMap; use std::ffi::c_void; -use std::marker::PhantomData; use std::sync::RwLock; const AFFINE: &str = include_str!("affine.metal"); @@ -1578,81 +1577,68 @@ fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } -pub struct Unary { - _marker: PhantomData, +pub fn call_fill( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: T, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Fill, T::FILL_KERNEL)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger); + + set_params!(encoder, (buffer, value, elem_count)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) } -pub trait FillOp { - const FILL_KERNEL: &'static str; +pub fn call_fill_u8( + command_buffer: &CommandBufferRef, + kernels: &Kernels, + elem_count: usize, + buffer: &Buffer, + value: u8, +) -> Result<(), MetalKernelError> { + let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&kernels.fence); + blit.fill_buffer( + buffer, + metal::NSRange { + location: 0, + length: elem_count as NSUInteger, + }, + value, + ); + blit.update_fence(&kernels.fence); + blit.end_encoding(); - fn fill( - device: &Device, - command_buffer: &CommandBufferRef, - kernels: &Kernels, - elem_count: usize, - buffer: &Buffer, - value: T, - ) -> Result<(), MetalKernelError>; + Ok(()) +} + +pub trait FillOp: EncoderParam { + const FILL_KERNEL: &'static str; } macro_rules ! impl_call_fill { ($($t:ty),*) => { $( - impl FillOp<$t> for Unary<$t> { + impl FillOp for $t { const FILL_KERNEL: &'static str = concat!("fill_", stringify!($t)); - - #[inline(always)] - fn fill(device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, elem_count: usize, buffer: &Buffer, value: $t) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Fill, Self::FILL_KERNEL)?; - let encoder = command_buffer.new_compute_command_encoder(); - encoder.wait_for_fence(&kernels.fence); - encoder.set_compute_pipeline_state(&pipeline); - encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger); - - set_params!(encoder, (buffer, value, elem_count)); - - let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.use_resource(buffer, metal::MTLResourceUsage::Write); - encoder.update_fence(&kernels.fence); - encoder.end_encoding(); - - Ok(()) - } } )* }; } impl_call_fill!(u32, i64, f16, bf16, f32); -impl FillOp for Unary { - const FILL_KERNEL: &'static str = ""; - - #[inline(always)] - fn fill( - _: &Device, - command_buffer: &CommandBufferRef, - kernels: &Kernels, - elem_count: usize, - buffer: &Buffer, - value: u8, - ) -> Result<(), MetalKernelError> { - let blit = command_buffer.new_blit_command_encoder(); - blit.wait_for_fence(&kernels.fence); - blit.fill_buffer( - &buffer, - metal::NSRange { - location: 0, - length: elem_count as NSUInteger, - }, - value, - ); - blit.update_fence(&kernels.fence); - blit.end_encoding(); - - Ok(()) - } -} - #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b7bff740..4b27d163 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -847,4 +847,4 @@ fn fill() { assert_fill(f16::from_f32(1.23)); assert_fill(bf16::from_f32(4.56)); assert_fill(7.89f32); -} \ No newline at end of file +}