From 7ff17d92b3b580ddb3f78023b2507d879b7a5c38 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 6 Nov 2023 23:12:12 +0100 Subject: [PATCH] Finished the unary - Added proper kernel type check (through modules + macro) - split contiguous and strided into 2 different kernels - Verified on long range + strided values. --- candle-metal-kernels/src/lib.rs | 332 +++++++++++++++++++++------ candle-metal-kernels/src/unary.metal | 56 ++--- 2 files changed, 288 insertions(+), 100 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index f877b5e3..01dd309b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,12 +1,48 @@ -use metal::{Buffer, CompileOptions, Device, Function, Library}; +use metal::{ + Buffer, CommandBuffer, CompileOptions, ComputePipelineDescriptor, Device, Function, Library, + MTLSize, +}; use once_cell::sync::Lazy; use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; -pub const AFFINE: &str = include_str!("affine.metal"); -pub const INDEXING: &str = include_str!("indexing.metal"); -pub const UNARY: &str = include_str!("unary.metal"); +const AFFINE: &str = include_str!("affine.metal"); +const INDEXING: &str = include_str!("indexing.metal"); +const UNARY: &str = include_str!("unary.metal"); + +macro_rules! 
unary{ + ($($name:ident),+) => { + + pub mod contiguous { + pub struct Kernel(pub &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); + } + )+ + } + + pub mod strided { + pub struct Kernel(pub &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); + } + )+ + } + }; +} + +pub mod unary { + unary!(cos, sin, exp); +} static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| { let mut l = HashMap::new(); @@ -104,24 +140,112 @@ impl Kernels { Ok(func) } } - - pub fn call_unary( - &self, - device: &Device, - library_name: &'static str, - name: &'static str, - input: &Buffer, - output: &mut Buffer, - length: usize, - ) -> Result<(), MetalKernelError> { - let func = self.load_function(device, library_name, name)?; - call_unary(&func, input, output, length); - Ok(()) - } } -fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usize) { - todo!("Call unary"); +pub fn call_unary_contiguous( + device: &Device, + command_buffer: &CommandBuffer, + kernels: &Kernels, + kernel_name: unary::contiguous::Kernel, + length: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + assert_eq!(input.length(), output.length()); + let func = kernels.load_function(device, "unary", kernel_name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder 
= command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, 4, void_ptr(&length)); + encoder.set_buffer(1, Some(&input), 0); + encoder.set_buffer(2, Some(&output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} +pub fn call_unary_strided( + device: &Device, + command_buffer: &CommandBuffer, + kernels: &Kernels, + name: unary::strided::Kernel, + input: &Buffer, + shape: &[usize], + strides: &[usize], + offset: usize, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, "unary", name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let num_dims: usize = shape.len() as usize; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length)); + encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims)); + encoder.set_bytes( + 2, + (shape.len() * std::mem::size_of::<usize>()) as u64, + shape.as_ptr() as *const c_void, + ); + encoder.set_bytes( + 3, + (strides.len() * std::mem::size_of::<usize>()) as u64, + strides.as_ptr() as *const c_void, + ); + encoder.set_bytes(4, std::mem::size_of::<usize>() as u64, void_ptr(&offset)); + + encoder.set_buffer(5, Some(&input), 0); + encoder.set_buffer(6, Some(&output), 0); + + let width = output.length(); + + let 
thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) } pub fn void_ptr<T>(v: &T) -> *const c_void { @@ -132,9 +256,7 @@ mod tests { use super::*; use half::f16; - use metal::{ - CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLSize, NSUInteger, - }; + use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; use std::mem; fn device() -> Device { @@ -151,58 +273,63 @@ mod tests { v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() } - fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> { + fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { let device = device(); - let options = MTLResourceOptions::StorageModeManaged; + let kernels = Kernels::new(); let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); + let command_buffer = command_queue.new_owned_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; let input = device.new_buffer_with_data( v.as_ptr() as *const core::ffi::c_void, (v.len() * core::mem::size_of::<T>()) as u64, options, ); - let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options); - let library = device - .new_library_with_source(UNARY, &CompileOptions::new()) - .expect("Failed to load unary library"); - let func = library.get_function(&format!("cos_{name}"), None).unwrap(); - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options); + call_unary_contiguous( + &device, + &command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) 
+ .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::<T>(v.len()) + } - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - - let dim: u32 = v.len() as u32; - // let num_dims: u32 = 1; - // let info = [v.len() as u32, 1]; - - let encoder = command_buffer.new_compute_command_encoder(); - encoder.set_compute_pipeline_state(&pipeline); - - encoder.set_bytes(0, 4, void_ptr(&dim)); - - encoder.set_buffer(1, Some(&input), 0); - encoder.set_buffer(2, Some(&output), 0); - - let width = v.len() as NSUInteger; - - let thread_group_count = MTLSize { - width, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width: pipeline.max_total_threads_per_threadgroup(), - height: 1, - depth: 1, - }; - - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + fn run_strided<T: Clone>( + v: &[T], + kernel: unary::strided::Kernel, + shape: &[usize], + strides: &[usize], + offset: usize, + ) -> Vec<T> { + let device = device(); + let options = MTLResourceOptions::StorageModeManaged; + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_owned_command_buffer(); + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + (v.len() * core::mem::size_of::<T>()) as u64, + options, + ); + let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options); + let kernels = Kernels::new(); + call_unary_strided( + &device, + &command_buffer, + &kernels, + kernel, + &input, + shape, + strides, + offset, + &mut output, + ) + .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); output.read_to_vec::<T>(v.len()) } @@ -211,10 +338,77 @@ mod tests { #[test] fn cos_f32() { let v = vec![1.0f32, 2.0, 3.0]; - let results = run_cos(&v, "float"); + let results = run(&v, unary::contiguous::cos::FLOAT); let expected: Vec<_> = 
v.iter().map(|v| v.cos()).collect(); assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + } + + #[test] + fn cos_f32_strided() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + // Shape = [6], strides = [1]; + let shape = vec![6]; + let strides = vec![1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Contiguous + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Transposed + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![1, 3]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Very large + let v = vec![1.0f32; 10_000]; + let shape = vec![2, 
5_000]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); } #[test] @@ -360,7 +554,7 @@ mod tests { .iter() .map(|v| f16::from_f32(*v)) .collect(); - let results = run_cos(&v, "half"); + let results = run(&v, unary::contiguous::cos::HALF); let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]); assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index fd9011ba..7349ce97 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -1,28 +1,19 @@ #include <metal_stdlib> -# -METAL_FUNC bool is_contiguous( - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides -) { - size_t acc = 1; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - if (acc != strides[dim_idx]) { - return false; - } - acc *= dims[dim_idx]; - } - return true; -} + +struct Info{ + device size_t &num_dims; + device size_t *dims; + device size_t *strides; +}; METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, constant size_t *dims, - constant size_t *strides + constant size_t *strides, + constant size_t &offset ) { - uint strided_i = 0; + uint strided_i = offset; for (uint d = 0; d < num_dims; d++) { uint dim_idx = num_dims - 1 - d; strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; @@ -40,37 +31,40 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const TYPENAME *input, \ device TYPENAME *output, \ uint threadgroup_size [[threads_per_threadgroup]], \ - uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ uint thread_index [[thread_index_in_threadgroup]] \ ) { \ - 
const uint i = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ - if (i > dim){ \ - return; \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = FN(input[i]); \ } \ - output[i] = FN(input[i]); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ constant size_t &num_dims, \ - constant size_t *info, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t &offset, \ device const TYPENAME *input, \ device TYPENAME *output, \ uint threadgroup_size [[threads_per_threadgroup]], \ - uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ uint thread_index [[thread_index_in_threadgroup]] \ ) { \ - constant size_t *dims = info; \ - constant size_t *strides = info + num_dims; \ - const uint start = thread_index + (threadgroup_position_in_grid * threadgroup_size); \ - const uint stop = min(thread_index + (threadgroup_position_in_grid * threadgroup_size), (uint) dim); \ - for (size_t i = start; i < stop; i++) { \ - output[i] = FN(input[get_strided_index(i, num_dims, dims, strides)]); \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = FN(input[get_strided_index(i, num_dims, dims, strides, offset)]); \ } \ } UNARY(cos, float, cos_float, cos_float_strided); UNARY(cos, half, cos_half, cos_half_strided); +UNARY(sin, float, sin_float, sin_float_strided); +UNARY(sin, half, sin_half, sin_half_strided); #if __METAL_VERSION__ >= 310 UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided); +UNARY(sin, bfloat, sin_bfloat, sin_bfloat_strided); #endif