From 0a29d2e9b85c3cb7aff4fff47ed94146dcf8cdb1 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 29 Dec 2023 12:27:12 +0100 Subject: [PATCH] Add fill kernel handler --- candle-metal-kernels/Cargo.toml | 2 +- candle-metal-kernels/src/lib.rs | 37 +++++++++++++++++++++++- candle-metal-kernels/src/tests.rs | 47 ++++++++++++++++++++++++++++++- 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 441d2e88..ba09ffcb 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -14,7 +14,7 @@ metal = { version = "0.27.0", features = ["mps"]} once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" +half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } [dev-dependencies] -half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } rand = "0.8.5" diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index dd97a86d..a2730632 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,3 +1,4 @@ +use half::{bf16, f16}; use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, @@ -12,6 +13,7 @@ const UNARY: &str = include_str!("unary.metal"); const BINARY: &str = include_str!("binary.metal"); const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); +const FILL: &str = include_str!("fill.metal"); const REDUCE: &str = include_str!("reduce.metal"); const CONV: &str = include_str!("conv.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); @@ -45,7 +47,7 @@ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, /// Helper functions to create the various objects on the compute command encoder /// on a single line. /// Prevents getting wrong some arguments number and mixing length and size in bytes. -trait EncoderParam { +pub trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } macro_rules! primitive { @@ -62,7 +64,11 @@ macro_rules! primitive { }; } primitive!(usize); +primitive!(u8); primitive!(u32); +primitive!(i64); +primitive!(f16); +primitive!(bf16); primitive!(f32); impl EncoderParam for &[T] { @@ -117,6 +123,7 @@ pub enum Source { Reduce, Mfa, Conv, + Fill, } macro_rules! ops{ @@ -227,6 +234,7 @@ impl Kernels { Source::Indexing => INDEXING, Source::Cast => CAST, Source::Reduce => REDUCE, + Source::Fill => FILL, Source::Conv => CONV, Source::Mfa => panic!("Invalid lib"), } @@ -1562,9 +1570,36 @@ pub fn call_upsample_nearest_2d( Ok(()) } +#[inline] fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } +pub fn call_fill( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + elem_count: usize, + buffer: &Buffer, + value: D, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Fill, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, elem_count as NSUInteger); + + set_params!(encoder, (buffer, value, elem_count)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.use_resource(buffer, metal::MTLResourceUsage::Write); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index c955abca..c1c7b8ab 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -11,7 +11,7 @@ fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { fn new_buffer(device: &Device, data: &[T]) -> Buffer { let options = MTLResourceOptions::StorageModeManaged; - let ptr = data.as_ptr() as *const core::ffi::c_void; + let ptr = data.as_ptr() as *const c_void; let size = (data.len() * std::mem::size_of::()) as u64; device.new_buffer_with_data(ptr, size, options) } @@ -806,3 +806,48 @@ fn gemm() { vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] ); } + +fn run_fill( + elem_count: usize, + value: T, + kernel_name: &'static str, +) -> Vec { + let device = device(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let buffer = new_buffer(&device, &vec![0.0f32; elem_count]); + call_fill( + &device, + command_buffer, + &kernels, + kernel_name, + elem_count, + &buffer, + value, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&buffer, elem_count) +} + +#[test] +fn fill() { + fn assert_fill( + value: T, + name: &'static str, + ) { + for i in 0..4 { + assert_eq!(run_fill(8 ^ i, value, name), vec![value; 8 ^ i]); + } + } + assert_fill(123u8, "fill_u8"); + assert_fill(456u32, "fill_u32"); + assert_fill(789i64, "fill_i64"); + assert_fill(f16::from_f32(1.23), "fill_f16"); + assert_fill(bf16::from_f32(4.56), "fill_bf16"); + assert_fill(7.89f32, "fill_f32"); +}