use metal::{
    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");

/// Most kernels apply similarly across the tensors.
/// This creates a dispatch strategy that uses the maximum number of threads per
/// threadgroup (capped at the total buffer length), so each kernel invocation can
/// simply operate on its single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
    let count = (size + width - 1) / width;
    let thread_group_count = MTLSize {
        width: count,
        height: 1,
        depth: 1,
    };
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    (thread_group_count, thread_group_size)
}
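// Worked example (hypothetical numbers): if `max_total_threads_per_threadgroup()`
// returns 1024 and `length == 4099`, `linear_split` yields 5 threadgroups of 1024
// threads each; the kernel is expected to bounds-check the 997 excess threads of
// the last group against `length`:
//
//     let (count, size) = linear_split(&pipeline, 4099);
//     assert_eq!((count.width, size.width), (5, 1024));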

fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
    <P as EncoderParam>::set_param(encoder, position, data)
}

/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting the argument count wrong and mixing up element length with
/// size in bytes.
trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}

macro_rules! primitive {
    ($type:ty) => {
        impl EncoderParam for $type {
            fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
                encoder.set_bytes(
                    position,
                    core::mem::size_of::<$type>() as u64,
                    &data as *const $type as *const c_void,
                );
            }
        }
    };
}
primitive!(usize);
primitive!(u32);
primitive!(f32);

impl<T> EncoderParam for &[T] {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of_val(data) as u64,
            data.as_ptr() as *const c_void,
        );
    }
}

impl EncoderParam for &Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}
impl EncoderParam for (&Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}
impl EncoderParam for &mut Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}
impl EncoderParam for (&mut Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}

macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        let mut _index = 0;
        $(
            set_param($encoder, _index, $param);
            _index += 1;
        )*
    );
}
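// Usage sketch (hypothetical encoder and buffers): `set_params!` binds each value
// at consecutive argument indices, dispatching on type through `EncoderParam`, so
// scalars go through `set_bytes`, `&Buffer` is bound at offset 0, and
// `(&Buffer, usize)` is bound at the given byte offset:
//
//     set_params!(encoder, (length, scale, input, (output, out_offset)));
//     // expands to: set_param(encoder, 0, length); set_param(encoder, 1, scale); ...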
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
    Affine,
    Indexing,
    Unary,
    Binary,
    Ternary,
    Cast,
    Reduce,
    Mfa,
}

macro_rules! ops {
    ($($name:ident),+) => {
        pub mod contiguous {
            pub struct Kernel(pub &'static str);
            $(
                pub mod $name {
                    use super::Kernel;
                    pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
                    pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
                    pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
                }
            )+
            pub mod copy {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel("copy_f32");
                pub const HALF: Kernel = Kernel("copy_f16");
                pub const BFLOAT: Kernel = Kernel("copy_bf16");
                pub const U32: Kernel = Kernel("copy_u32");
                pub const U8: Kernel = Kernel("copy_u8");
            }
        }

        pub mod strided {
            pub struct Kernel(pub &'static str);
            $(
                pub mod $name {
                    use super::Kernel;
                    pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
                    pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
                    pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
                }
            )+
            pub mod copy {
                use super::Kernel;
                pub const FLOAT: Kernel = Kernel("copy_f32_strided");
                pub const HALF: Kernel = Kernel("copy_f16_strided");
                pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
                pub const U32: Kernel = Kernel("copy_u32_strided");
                pub const U8: Kernel = Kernel("copy_u8_strided");
            }
        }
    };
}

pub mod unary {
    ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
    ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
}

#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
    #[error("Could not lock kernel map: {0}")]
    LockError(String),
    #[error("Error while loading library: {0}")]
    LoadLibraryError(String),
    #[error("Error while loading function: {0:?}")]
    LoadFunctionError(String),
    #[error("Failed to create compute function")]
    FailedToCreateComputeFunction,
    #[error("Failed to create pipeline")]
    FailedToCreatePipeline(String),
    #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
        mnk: (usize, usize, usize),
    },
}

impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
    fn from(e: std::sync::PoisonError<T>) -> Self {
        Self::LockError(e.to_string())
    }
}

type Libraries = HashMap<Source, Library>;
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;

#[derive(Debug)]
pub struct Kernels {
    libraries: RwLock<Libraries>,
    pipelines: RwLock<Pipelines>,
    fence: metal::Fence,
}

impl Kernels {
    pub fn new(fence: metal::Fence) -> Self {
        let libraries = RwLock::new(Libraries::new());
        let pipelines = RwLock::new(Pipelines::new());
        Self {
            libraries,
            pipelines,
            fence,
        }
    }

    fn get_library_source(&self, source: Source) -> &'static str {
        match source {
            Source::Affine => AFFINE,
            Source::Unary => UNARY,
            Source::Binary => BINARY,
            Source::Ternary => TERNARY,
            Source::Indexing => INDEXING,
            Source::Cast => CAST,
            Source::Reduce => REDUCE,
            Source::Mfa => panic!("Invalid lib"),
        }
    }

    /// Load the given library from its [`source`].
    /// If it has been previously loaded, it is fetched from the cache instead.
    pub fn load_library(
        &self,
        device: &Device,
        source: Source,
    ) -> Result<Library, MetalKernelError> {
        let mut libraries = self.libraries.write()?;
        if let Some(lib) = libraries.get(&source) {
            Ok(lib.clone())
        } else {
            let lib = match source {
                Source::Mfa => {
                    let source_data = MFA;
                    device.new_library_with_data(source_data).map_err(|e| {
                        MetalKernelError::LoadLibraryError(format!(
                            "Candle metal requires macOS 13.0 or higher, cannot load mfa: {e}"
                        ))
                    })?
                }
                source => {
                    let source_content = self.get_library_source(source);
                    device
                        .new_library_with_source(source_content, &CompileOptions::new())
                        .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
                }
            };
            libraries.insert(source, lib.clone());
            Ok(lib)
        }
    }

    fn load_function(
        &self,
        device: &Device,
        source: Source,
        name: &'static str,
        constants: Option<FunctionConstantValues>,
    ) -> Result<Function, MetalKernelError> {
        let func = self
            .load_library(device, source)?
            .get_function(name, constants)
            .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
        Ok(func)
    }

    /// Load the given pipeline:
    /// loads the library from source, then gets the function [`name`] from
    /// that source.
    fn load_pipeline_with_constants(
        &self,
        device: &Device,
        source: Source,
        name: &'static str,
        constants: Option<ConstantValues>,
    ) -> Result<ComputePipelineState, MetalKernelError> {
        let mut pipelines = self.pipelines.write()?;
        let key = (name, constants);
        if let Some(pipeline) = pipelines.get(&key) {
            Ok(pipeline.clone())
        } else {
            let (name, constants) = key;
            let func = self.load_function(
                device,
                source,
                name,
                constants.as_ref().map(|c| c.function_constant_values()),
            )?;
            let pipeline = device
                .new_compute_pipeline_state_with_function(&func)
                .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
            pipelines.insert((name, constants), pipeline.clone());
            Ok(pipeline)
        }
    }

    /// Load the given pipeline:
    /// loads the library from source, then gets the function [`name`] from
    /// that source (without constants).
    pub fn load_pipeline(
        &self,
        device: &Device,
        source: Source,
        name: &'static str,
    ) -> Result<ComputePipelineState, MetalKernelError> {
        self.load_pipeline_with_constants(device, source, name, None)
    }
}
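// Usage sketch (hypothetical device): kernel names are generated by the `ops!`
// macro above, e.g. `unary::contiguous::cos::FLOAT` is `Kernel("cos_f32")` and
// `unary::strided::cos::HALF` is `Kernel("cos_f16_strided")`. The first
// `load_pipeline` call for a name compiles the library and pipeline state; later
// calls return clones from the `RwLock`-protected caches:
//
//     let kernels = Kernels::new(device.new_fence());
//     let pipeline = kernels.load_pipeline(&device, Source::Unary, "cos_f32")?;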
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, input, output));
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: unary::strided::Kernel,
    shape: &[usize],
    input: &Buffer,
    strides: &[usize],
    offset: usize,
    output: &Buffer,
    output_offset: usize,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    let length: usize = shape.iter().product();
    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            strides,
            (input, offset),
            (output, output_offset)
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
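// Usage sketch (hypothetical device and buffers; committing and waiting on the
// command buffer is left to the caller): apply `cos` to 1024 contiguous f32
// values, writing into a preallocated output buffer of the same size:
//
//     call_unary_contiguous(
//         &device, &command_buffer, &kernels,
//         unary::contiguous::cos::FLOAT, 1024, &input, &output,
//     )?;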
#[allow(clippy::too_many_arguments)]
pub fn call_binary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: binary::contiguous::Kernel,
    length: usize,
    left: &Buffer,
    right: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, left, right, output));
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(left, metal::MTLResourceUsage::Read);
    encoder.use_resource(right, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_binary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: binary::strided::Kernel,
    shape: &[usize],
    left_input: &Buffer,
    left_strides: &[usize],
    left_offset: usize,
    right_input: &Buffer,
    right_strides: &[usize],
    right_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    let length: usize = shape.iter().product();
    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            left_strides,
            right_strides,
            (left_input, left_offset),
            (right_input, right_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
    encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_cast_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, (input, input_offset), output));
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_cast_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input_strides: &[usize],
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    let length: usize = shape.iter().product();
    set_params!(
        encoder,
        (
            length,
            shape.len(),
            shape,
            input_strides,
            (input, input_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
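// Broadcast sketch (hypothetical shapes): the strided variants take one stride
// vector per input, so a broadcast add of a [4, 1] tensor against a [4, 8] tensor
// can pass shape [4, 8] with a zero stride on the broadcast dimension:
//
//     call_binary_strided(
//         &device, &command_buffer, &kernels, binary::strided::add::FLOAT,
//         &[4, 8], &left, &[1, 0], 0, &right, &[8, 1], 0, &output,
//     )?;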
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    out_length: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let elements_to_sum = length / out_length;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (length, elements_to_sum, (input, input_offset), output)
    );
    // One threadgroup per output element; ceil(elements_to_sum / 2) threads per
    // group, capped at the pipeline maximum and rounded up to a power of two.
    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };
    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        (elements_to_sum as u64 + 2 - 1) / 2,
    )
    .next_power_of_two();
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_length: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let length: usize = shape.iter().product();
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let elements_to_sum = length / out_length;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            shape.len(),
            shape,
            strides,
            elements_to_sum,
            (input, input_offset),
            output
        )
    );
    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };
    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        elements_to_sum as u64,
    )
    .next_power_of_two();
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    elements_to_sum: usize,
    input: &Buffer,
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (length, elements_to_sum, (input, input_offset), output)
    );
    let out_length = length / elements_to_sum;
    let thread_group_count = MTLSize {
        width: out_length as u64,
        height: 1,
        depth: 1,
    };
    let width = std::cmp::min(
        pipeline.max_total_threads_per_threadgroup(),
        elements_to_sum as u64,
    )
    .next_power_of_two();
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
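// Sizing sketch (hypothetical numbers): reducing a contiguous [16, 4096] tensor
// over its last dimension means `length == 65536`, `out_length == 16`, and
// `elements_to_sum == 4096`. With a 1024-thread pipeline limit,
// `call_reduce_contiguous` dispatches 16 threadgroups (one per output element) of
// `min(1024, 2048).next_power_of_two() == 1024` threads each.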
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    output: &Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (size, mul, add, input, output));
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_affine_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            add,
            (input, input_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_powf(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (size, mul, input, output));
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_powf_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            (input, input_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
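// Usage sketch (hypothetical buffers; the "affine_f32" name assumes the f32
// kernel in `affine.metal`): the affine kernel computes
// `output[i] = input[i] * mul + add`, here scaling 256 elements by 0.5 and
// shifting by 1.0:
//
//     call_affine(&device, &command_buffer, &kernels, "affine_f32",
//                 256, &input, &output, 0.5, 1.0)?;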
#[allow(clippy::too_many_arguments)]
pub fn call_elu(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (size, mul, input, output));
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_elu_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
    let size: usize = shape.iter().product();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            (input, input_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    cond: &Buffer,
    (cond_stride, cond_offset): (&[usize], usize),
    left: &Buffer,
    (left_stride, left_offset): (&[usize], usize),
    right: &Buffer,
    (right_stride, right_offset): (&[usize], usize),
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    let size: usize = shape.iter().product();
    let rank = shape.len();
    set_params!(
        encoder,
        (
            size,
            rank,
            shape,
            cond_stride,
            left_stride,
            right_stride,
            (cond, cond_offset),
            (left, left_offset),
            (right, right_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(cond, metal::MTLResourceUsage::Read);
    encoder.use_resource(left, metal::MTLResourceUsage::Read);
    encoder.use_resource(right, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
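// Semantics sketch: the ternary kernels select elementwise,
// `output[i] = if cond[i] != 0 { left[i] } else { right[i] }`, and each of
// `cond`/`left`/`right` is addressed through its own (strides, offset) pair, so
// non-contiguous inputs need no preliminary copy.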
#[allow(clippy::too_many_arguments)]
pub fn call_index_select(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: &Buffer,
    ids: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let src_dim_size = shape[dim];
    let dst_el = ids_size * left_size * right_size;
    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            ids_size,
            input,
            ids,
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_gather(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
    let src_dim_size = shape[dim];
    let dst_el = ids_size * left_size * right_size;
    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            ids_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];
    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}
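// Shape sketch (hypothetical sizes): `call_index_select` on a [4, 5, 6] input
// with `dim == 1` and `ids_size == 3` gives `left_size == 4`, `src_dim_size == 5`
// and `right_size == 6`, hence `dst_el == 3 * 4 * 6 == 72` output elements, one
// thread per element via `linear_split`.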
#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    ids_shape: &[usize],
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];
    let ids_dim_size = ids_shape[0];
    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            ids_dim_size,
            (input, input_offset),
            (ids, ids_offset),
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

#[derive(Debug, PartialEq)]
pub enum Value {
    USize(usize),
    Bool(bool),
    F32(f32),
    U16(u16),
}

impl std::hash::Hash for Value {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            Value::F32(v) => v.to_bits().hash(state),
            Value::USize(v) => v.hash(state),
            Value::U16(v) => v.hash(state),
            Value::Bool(v) => v.hash(state),
        }
    }
}

impl Value {
    fn data_type(&self) -> MTLDataType {
        match self {
            Value::USize(_) => MTLDataType::UInt,
            Value::F32(_) => MTLDataType::Float,
            Value::U16(_) => MTLDataType::UShort,
            Value::Bool(_) => MTLDataType::Bool,
        }
    }
}
/// Not strictly true (the derived `PartialEq` is not reflexive for NaN `f32`
/// values), but good enough for our purposes.
impl Eq for Value {}

#[derive(Debug, Eq, PartialEq, Hash)]
struct ConstantValues(Vec<(usize, Value)>);

impl ConstantValues {
    pub fn new(values: Vec<(usize, Value)>) -> Self {
        Self(values)
    }

    fn function_constant_values(&self) -> FunctionConstantValues {
        let f = FunctionConstantValues::new();
        for (index, value) in &self.0 {
            let ty = value.data_type();
            match value {
                Value::USize(v) => {
                    f.set_constant_value_at_index(
                        v as *const usize as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::F32(v) => {
                    f.set_constant_value_at_index(
                        v as *const f32 as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::U16(v) => {
                    f.set_constant_value_at_index(
                        v as *const u16 as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::Bool(v) => {
                    f.set_constant_value_at_index(
                        v as *const bool as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
            }
        }
        f
    }
}
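// Usage sketch (the indices match the ones hard-coded for the MFA gemm kernels in
// `call_gemm` below): a `ConstantValues` list is converted to
// `FunctionConstantValues` when the function is loaded, and is also part of the
// pipeline cache key, so every distinct constant combination gets its own
// specialized pipeline:
//
//     let constants = ConstantValues::new(vec![
//         (0, Value::USize(m)),
//         (1, Value::USize(n)),
//         (2, Value::USize(k)),
//     ]);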
#[allow(clippy::too_many_arguments)]
pub fn call_gemm(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs_stride: &[usize],
    lhs_offset: usize,
    lhs_buffer: &Buffer,
    rhs_stride: &[usize],
    rhs_offset: usize,
    rhs_buffer: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    assert!(rhs_stride.len() >= 2);
    assert!(lhs_stride.len() >= 2);
    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
    let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
        false
    } else if lhs_m1 == m && lhs_m2 == 1 {
        true
    } else {
        return Err(MetalKernelError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        })?;
    };
    let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
        false
    } else if rhs_m1 == k && rhs_m2 == 1 {
        true
    } else {
        return Err(MetalKernelError::MatMulNonContiguous {
            lhs_stride: lhs_stride.to_vec(),
            rhs_stride: rhs_stride.to_vec(),
            mnk: (m, n, k),
        })?;
    };
    let d_trans = false;
    let alpha = 1.0f32;
    let beta = 0.0f32;
    let batched = b > 1;
    let fused_activation = false;
    let fused_bias = false;
    let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
        let m_simd = 16;
        let n_simd = 8;
        let k_simd = 64;
        let m_splits = 1;
        let n_splits = 1;
        (m_simd, n_simd, k_simd, m_splits, n_splits)
    } else {
        let m_simd = 40;
        let n_simd = 40;
        let k_simd = 8;
        let m_splits = 1;
        let n_splits = 1;
        (m_simd, n_simd, k_simd, m_splits, n_splits)
    };
    let constants = Some(ConstantValues::new(vec![
        (0, Value::USize(m)),
        (1, Value::USize(n)),
        (2, Value::USize(k)),
        (10, Value::Bool(a_trans)),
        (11, Value::Bool(b_trans)),
        (13, Value::Bool(d_trans)),
        (20, Value::F32(alpha)),
        (21, Value::F32(beta)),
        (100, Value::Bool(batched)),
        (101, Value::Bool(fused_activation)),
        // Garbage
        (102, Value::Bool(false)),
        (103, Value::Bool(false)),
        (113, Value::Bool(false)),
        (50_000, Value::Bool(false)),
        // End garbage
        (200, Value::U16(m_simd)),
        (201, Value::U16(n_simd)),
        (202, Value::U16(k_simd)),
        (210, Value::U16(m_splits)),
        (211, Value::U16(n_splits)),
        (50_001, Value::Bool(fused_bias)),
    ]));
    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
    let m_group = m_simd * m_splits;
    let n_group = n_simd * n_splits;

    let a_block_length = m_group * k_simd;
    let b_block_length = k_simd * n_group;

    let mut block_elements = a_block_length + b_block_length;
    if (m % 8 != 0) && (n % 8 != 0) {
        let c_block_length = m_group * n_group;
        block_elements = std::cmp::max(c_block_length, block_elements)
    }
    if fused_bias {
        if d_trans {
            block_elements = std::cmp::max(block_elements, m_group);
        } else {
            block_elements = std::cmp::max(block_elements, n_group);
        }
    }
    let bytes = match name {
        "sgemm" => 4,
        "hgemm" => 2,
        other => {
            return Err(MetalKernelError::LoadLibraryError(format!(
                "{other} is not a valid kernel for gemm"
            )));
        }
    };
    let block_bytes = block_elements * bytes;

    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(&kernels.fence);
    encoder.set_compute_pipeline_state(&pipeline);
    encoder.set_threadgroup_memory_length(0, block_bytes.into());
    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
    encoder.set_buffer(2, Some(output), 0);
    // TODO Tensor D

    let grid_z = b;
    if batched {
        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
        let byte_stride_c = m * n * bytes as usize;
        // TODO byte_stride_d
        let byte_stride_d = 0;

        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
        for i in 0..b {
            buffer.push((i * byte_stride_a) as u64);
            buffer.push((i * byte_stride_b) as u64);
            buffer.push((i * byte_stride_c) as u64);
            buffer.push((i * byte_stride_d) as u64);
        }
        encoder.set_bytes(
            10,
            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
            buffer.as_ptr() as *const NSUInteger as *const c_void,
        );
    }

    let grid_size = MTLSize {
        width: divide(n, n_group.into()),
        height: divide(m, m_group.into()),
        depth: grid_z as NSUInteger,
    };
    let group_size = MTLSize {
        width: 32 * (m_splits as u64) * (n_splits as u64),
        height: 1,
        depth: 1,
    };
    encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(grid_size, group_size);
    encoder.update_fence(&kernels.fence);
    encoder.end_encoding();
    Ok(())
}

fn divide(m: usize, b: usize) -> NSUInteger {
    ((m + b - 1) / b) as NSUInteger
}

#[cfg(test)]
mod tests;