use metal::{ Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; pub mod mlx_gemm; pub mod sort; pub mod utils; pub use mlx_gemm::{call_mlx_gemm, GemmDType}; pub use sort::{call_arg_sort, call_mlx_arg_sort}; pub use utils::BufferOffset; use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider}; const AFFINE: &str = include_str!("affine.metal"); const BINARY: &str = include_str!("binary.metal"); const CAST: &str = include_str!("cast.metal"); const CONV: &str = include_str!("conv.metal"); const FILL: &str = include_str!("fill.metal"); const INDEXING: &str = include_str!("indexing.metal"); const MLX_GEMM: &str = include_str!("mlx_gemm.metal"); const MLX_SORT: &str = include_str!("mlx_sort.metal"); const QUANTIZED: &str = include_str!("quantized.metal"); const RANDOM: &str = include_str!("random.metal"); const REDUCE: &str = include_str!("reduce.metal"); const SORT: &str = include_str!("sort.metal"); const TERNARY: &str = include_str!("ternary.metal"); const UNARY: &str = include_str!("unary.metal"); const SDPA: &str = include_str!("scaled_dot_product_attention.metal"); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum DType { BF16, F16, F32, I64, U32, U8, } impl DType { fn size_in_bytes(&self) -> usize { match self { Self::U8 => 1, Self::U32 => 4, Self::I64 => 8, Self::BF16 => 2, Self::F16 => 2, Self::F32 => 4, } } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { Affine, Binary, Cast, Conv, Fill, Gemm, Indexing, MlxSort, Quantized, Random, Reduce, Sort, Ternary, Unary, Sdpa, } pub mod copy2d { pub struct Kernel(pub &'static str); pub const FLOAT: Kernel = Kernel("copy2d_f32"); pub const HALF: Kernel = Kernel("copy2d_f16"); pub const BFLOAT: Kernel = Kernel("copy2d_bf16"); pub const I64: Kernel = Kernel("copy2d_i64"); pub const U32: Kernel = 
Kernel("copy2d_u32"); pub const U8: Kernel = Kernel("copy2d_u8"); } macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } )+ pub mod copy { use super::Kernel; pub const FLOAT: Kernel = Kernel("copy_f32"); pub const HALF: Kernel = Kernel("copy_f16"); pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const I64: Kernel = Kernel("copy_i64"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } } pub mod contiguous_tiled { pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); } )+ pub mod copy { use super::Kernel; pub const FLOAT: Kernel = Kernel("copy_f32_tiled"); pub const HALF: Kernel = Kernel("copy_f16_tiled"); pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); pub const I64: Kernel = Kernel("copy_i64_tiled"); pub const U32: Kernel = Kernel("copy_u32_tiled"); pub const U8: Kernel = Kernel("copy_u8_tiled"); } } pub mod strided { pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), 
"_f32_strided")); pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } )+ pub mod copy { use super::Kernel; pub const FLOAT: Kernel = Kernel("copy_f32_strided"); pub const HALF: Kernel = Kernel("copy_f16_strided"); pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const I64: Kernel = Kernel("copy_i64_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } } }; } pub mod unary { ops!( cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, tanh, recip, silu, sign, sigmoid, const_set ); } pub mod binary { ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); } #[derive(thiserror::Error, Debug)] pub enum MetalKernelError { #[error("Could not lock kernel map: {0}")] LockError(String), #[error("Error while loading library: {0}")] LoadLibraryError(String), #[error("Error while loading function: {0}")] LoadFunctionError(String), #[error("Failed to create compute function")] FailedToCreateComputeFunction, #[error("Failed to create pipeline")] FailedToCreatePipeline(String), #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")] MatMulNonContiguous { lhs_stride: Vec, rhs_stride: Vec, mnk: (usize, usize, usize), }, #[error("Sdpa {variation} head size was {got}, expectd {expected:?}")] SdpaHeadSizeMismatch { variation: &'static str, got: usize, expected: Vec, }, #[error("Sdpa {variation} got dtype {got:?}")] SdpaHeadDTypeMismatch { variation: &'static str, got: SdpaDType, }, } impl From> for MetalKernelError { fn from(e: std::sync::PoisonError) -> Self { Self::LockError(e.to_string()) } } #[derive(Debug, Clone)] pub enum 
KernelName { Ref(&'static str), Value(String), } impl AsRef for KernelName { fn as_ref(&self) -> &str { match self { Self::Ref(r) => r, Self::Value(v) => v.as_str(), } } } impl std::hash::Hash for KernelName { fn hash(&self, state: &mut H) { match self { Self::Ref(r) => r.hash(state), Self::Value(v) => v.hash(state), } } } impl PartialEq for KernelName { fn eq(&self, other: &Self) -> bool { let v1: &str = self.as_ref(); let v2: &str = other.as_ref(); v1 == v2 } } impl Eq for KernelName {} impl From<&'static str> for KernelName { fn from(value: &'static str) -> Self { Self::Ref(value) } } impl From for KernelName { fn from(value: String) -> Self { Self::Value(value) } } type Libraries = HashMap; type Pipelines = HashMap<(KernelName, Option), ComputePipelineState>; #[derive(Debug)] pub struct Kernels { libraries: RwLock, pipelines: RwLock, } impl Default for Kernels { fn default() -> Self { Self::new() } } impl Kernels { pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); let pipelines = RwLock::new(Pipelines::new()); Self { libraries, pipelines, } } fn get_library_source(&self, source: Source) -> &'static str { match source { Source::Affine => AFFINE, Source::Binary => BINARY, Source::Cast => CAST, Source::Conv => CONV, Source::Fill => FILL, Source::Gemm => MLX_GEMM, Source::Indexing => INDEXING, Source::MlxSort => MLX_SORT, Source::Quantized => QUANTIZED, Source::Random => RANDOM, Source::Reduce => REDUCE, Source::Sort => SORT, Source::Ternary => TERNARY, Source::Unary => UNARY, Source::Sdpa => SDPA, } } /// Load the give library from its [`source`]. /// If this has been previously loaded it will just fetch it from cache. 
pub fn load_library( &self, device: &Device, source: Source, ) -> Result { let mut libraries = self.libraries.write()?; if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { let lib = { let source_content = self.get_library_source(source); device .new_library_with_source(source_content, &CompileOptions::new()) .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? }; libraries.insert(source, lib.clone()); Ok(lib) } } fn load_function( &self, device: &Device, source: Source, name: &str, constants: Option, ) -> Result { let func = self .load_library(device, source)? .get_function(name, constants) .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; Ok(func) } /// Load the give pipeline /// loads the library from source, then gets the function [`name`] from /// that source fn load_pipeline_with_constants( &self, device: &Device, source: Source, name: impl Into, constants: Option, ) -> Result { let mut pipelines = self.pipelines.write()?; let key = (name.into(), constants); if let Some(pipeline) = pipelines.get(&key) { Ok(pipeline.clone()) } else { let (name, constants) = key; let func = self.load_function( device, source, name.as_ref(), constants.as_ref().map(|c| c.function_constant_values()), )?; let pipeline = device .new_compute_pipeline_state_with_function(&func) .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; pipelines.insert((name, constants), pipeline.clone()); Ok(pipeline) } } /// Load the give pipeline /// loads the library from source, then gets the function [`name`] from /// that source (without constants) pub fn load_pipeline( &self, device: &Device, source: Source, name: impl Into, ) -> Result { self.load_pipeline_with_constants(device, source, name, None) } } #[allow(clippy::too_many_arguments)] pub fn call_copy2d( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: copy2d::Kernel, input: &Buffer, output: &Buffer, d1: usize, d2: usize, src_s: usize, dst_s: usize, 
src_o_in_bytes: usize, dst_o_in_bytes: usize, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( d1 as i64, d2 as i64, src_s as i64, dst_s as i64, (input, src_o_in_bytes), (output, dst_o_in_bytes) ) ); let grid_dims = MTLSize { width: d1 as u64, height: d2 as u64, depth: 1, }; let group_dims = get_block_dims(d1 as u64, d2 as u64, 1); encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_threads(grid_dims, group_dims); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_const_set_contiguous_tiled( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: unary::contiguous_tiled::Kernel, length: usize, input: impl EncoderParam, output: BufferOffset, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let tile_size = 2; let tiles = length.div_ceil(tile_size); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, &output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_const_set_contiguous( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: unary::contiguous::Kernel, length: usize, input: impl EncoderParam, output: BufferOffset, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); 
encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, &output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_const_set_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: unary::strided::Kernel, shape: &[usize], input: impl EncoderParam, strides: &[usize], output: BufferOffset, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let length: usize = shape.iter().product(); let num_dims: usize = shape.len(); let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, num_dims, shape, strides, input, &output)); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous_tiled( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: unary::contiguous_tiled::Kernel, length: usize, input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let tile_size = 2; let tiles = length.div_ceil(tile_size); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); 
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: unary::contiguous::Kernel, length: usize, input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_unary_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: unary::strided::Kernel, shape: &[usize], input: BufferOffset, strides: &[usize], output: BufferOffset, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let length: usize = shape.iter().product(); let num_dims: usize = shape.len(); let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, num_dims, shape, strides, &input, &output)); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_binary_contiguous( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: binary::contiguous::Kernel, length: usize, 
left: BufferOffset, right: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, &left, &right, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_binary_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: binary::strided::Kernel, shape: &[usize], left_input: BufferOffset, left_strides: &[usize], right_input: BufferOffset, right_strides: &[usize], output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; let num_dims: usize = shape.len(); let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let width: usize = shape.iter().product(); let length: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( length, num_dims, shape, left_strides, right_strides, &left_input, &right_input, output ) ); encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_cast_contiguous( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: 
&'static str, length: usize, input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_cast_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, shape: &[usize], input: BufferOffset, input_strides: &[usize], output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); set_params!( encoder, (length, shape.len(), shape, input_strides, &input, output) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, shape: &[usize], out_length: usize, input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let length = shape.iter().product::(); let num_dims = shape.len(); let work_per_threadgroup = length / out_length; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; 
let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( length, num_dims, shape, work_per_threadgroup, &input, output ) ); let thread_group_count = MTLSize { width: out_length as u64, height: 1, depth: 1, }; let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), (work_per_threadgroup / 2).next_power_of_two() as NSUInteger, ); let thread_group_size = MTLSize { width, height: 1, depth: 1, }; encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_reduce_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, shape: &[usize], strides: &[usize], out_length: usize, input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let length: usize = shape.iter().product(); let num_dims = shape.len(); let work_per_threadgroup = length / out_length; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( length, num_dims, shape, strides, work_per_threadgroup, &input, output ) ); let thread_group_count = MTLSize { width: out_length as u64, height: 1, depth: 1, }; let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), (work_per_threadgroup / 2).next_power_of_two() as NSUInteger, ); let thread_group_size = MTLSize { width, height: 1, depth: 1, }; encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_last_softmax( 
device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, length: usize, elements: usize, input: &Buffer, input_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let work_per_threadgroup = elements; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, (length, work_per_threadgroup, (input, input_offset), output) ); let out_length = length / work_per_threadgroup; let thread_group_count = MTLSize { width: out_length as NSUInteger, height: 1, depth: 1, }; let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), (work_per_threadgroup / 2).next_power_of_two() as NSUInteger, ); let thread_group_size = MTLSize { width, height: 1, depth: 1, }; encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_rms_norm( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, length: usize, elements_to_sum: usize, eps: f32, input: &Buffer, input_offset: usize, alpha: &Buffer, alpha_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( length, elements_to_sum, (input, input_offset), output, (alpha, alpha_offset), eps ) ); let out_length = length / elements_to_sum; let thread_group_count = MTLSize { width: out_length as u64, height: 1, depth: 1, }; let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), elements_to_sum as u64, ) .next_power_of_two(); let 
thread_group_size = MTLSize { width, height: 1, depth: 1, }; encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_layer_norm( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, length: usize, elements_to_sum: usize, eps: f32, input: &Buffer, input_offset: usize, alpha: &Buffer, alpha_offset: usize, beta: &Buffer, beta_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( length, elements_to_sum, (input, input_offset), output, (alpha, alpha_offset), (beta, beta_offset), eps ) ); let out_length = length / elements_to_sum; let thread_group_count = MTLSize { width: out_length as u64, height: 1, depth: 1, }; let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), elements_to_sum as u64, ) .next_power_of_two(); let thread_group_size = MTLSize { width, height: 1, depth: 1, }; encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_rope_i( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, bh: usize, td: usize, stride_b: usize, src: &Buffer, src_offset: usize, cos: &Buffer, cos_offset: usize, sin: &Buffer, sin_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = 
kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( bh, td, stride_b, (src, src_offset), (cos, cos_offset), (sin, sin_offset), output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2); encoder.use_resource(src, metal::MTLResourceUsage::Read); encoder.use_resource(cos, metal::MTLResourceUsage::Read); encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_rope_thd( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, b: usize, t: usize, h: usize, d: usize, stride_b: usize, src: &Buffer, src_offset: usize, cos: &Buffer, cos_offset: usize, sin: &Buffer, sin_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( b, t, h, d, stride_b, (src, src_offset), (cos, cos_offset), (sin, sin_offset), output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2); encoder.use_resource(src, metal::MTLResourceUsage::Read); encoder.use_resource(cos, metal::MTLResourceUsage::Read); encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_rope( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, bh: usize, td: usize, d: usize, stride_b: usize, src: 
&Buffer, src_offset: usize, cos: &Buffer, cos_offset: usize, sin: &Buffer, sin_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( bh, td, d, stride_b, (src, src_offset), (cos, cos_offset), (sin, sin_offset), output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2); encoder.use_resource(src, metal::MTLResourceUsage::Read); encoder.use_resource(cos, metal::MTLResourceUsage::Read); encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_affine( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, size: usize, input: BufferOffset, output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_affine_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], input: BufferOffset, input_stride: &[usize], output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, 
Source::Affine, name)?; let size: usize = shape.iter().product(); let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( size, shape.len(), shape, input_stride, mul, add, &input, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_powf( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, size: usize, input: BufferOffset, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_powf_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], input: BufferOffset, input_stride: &[usize], output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let size: usize = shape.iter().product(); let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, (size, shape.len(), shape, input_stride, mul, &input, output) ); let (thread_group_count, thread_group_size) 
= linear_split(&pipeline, size); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_elu( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, size: usize, input: BufferOffset, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_elu_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], input: BufferOffset, input_stride: &[usize], output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let size: usize = shape.iter().product(); let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, (size, shape.len(), shape, input_stride, mul, &input, output) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_where_cond_strided( device: &Device, ep: impl 
EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    cond: BufferOffset,
    cond_stride: &[usize],
    left: BufferOffset,
    left_stride: &[usize],
    right: BufferOffset,
    right_stride: &[usize],
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
    let numel: usize = shape.iter().product();
    let rank = shape.len();
    let (groups, threads_per_group) = linear_split(&pipeline, numel);

    let encoder_handle = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder_handle.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            numel,
            rank,
            shape,
            cond_stride,
            left_stride,
            right_stride,
            &cond,
            &left,
            &right,
            output
        )
    );

    // The three operands are declared read-only, the output write-only.
    for operand in [&cond, &left, &right] {
        encoder.use_resource(operand.buffer, metal::MTLResourceUsage::Read);
    }
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}

/// Encodes an indexing-library `index_select` kernel `name`.
///
/// `shape` is the source shape and `dim` the indexed axis. The output holds
/// `ids_size` entries along `dim`, i.e.
/// `ids_size * prod(shape[..dim]) * prod(shape[dim + 1..])` elements.
/// `contiguous`, `src_dims` and `src_strides` are forwarded to the kernel as-is.
///
/// # Panics
/// Panics if `dim >= shape.len()`.
#[allow(clippy::too_many_arguments)]
pub fn call_index_select(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    contiguous: bool,
    src_dims: &[usize],
    src_strides: &[usize],
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let outer: usize = shape[..dim].iter().product();
    let inner: usize = shape[dim + 1..].iter().product();
    let indexed_dim = shape[dim];
    let out_numel = ids_size * outer * inner;

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let (groups, threads_per_group) = linear_split(&pipeline, out_numel);

    let encoder_handle = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder_handle.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            out_numel,
            outer,
            indexed_dim,
            inner,
            ids_size,
            contiguous,
            src_dims,
            src_strides,
            &input,
            &ids,
            output
        )
    );

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}

/// Encodes an indexing-library `gather` kernel `name`.
///
/// `shape` is the source shape and `dim` the gathered axis. The output holds
/// `ids_size * prod(shape[..dim]) * prod(shape[dim + 1..])` elements.
///
/// # Panics
/// Panics if `dim >= shape.len()`.
#[allow(clippy::too_many_arguments)]
pub fn call_gather(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let outer: usize = shape[..dim].iter().product();
    let inner: usize = shape[dim + 1..].iter().product();
    let gathered_dim = shape[dim];
    let out_numel = ids_size * outer * inner;

    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let (groups, threads_per_group) = linear_split(&pipeline, out_numel);

    let encoder_handle = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder_handle.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            out_numel,
            outer,
            gathered_dim,
            inner,
            ids_size,
            &input,
            &ids,
            output
        )
    );

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}

/// Encodes an indexing-library scatter kernel `name` that writes into
/// `output` along `dim` of `dst_shape`.
#[allow(clippy::too_many_arguments)]
pub fn call_scatter(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    dim: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];
    let pipeline =
kernels.load_pipeline(device, Source::Indexing, name)?;
    let (groups, threads_per_group) = linear_split(&pipeline, dst_el);
    let encoder_handle = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder_handle.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            &input,
            &ids,
            &output
        )
    );
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}

/// Encodes an indexing-library `index_add`-style kernel `name`.
///
/// Work is sized from `src_shape` around `dim`:
/// `prod(src_shape[..dim]) * prod(src_shape[dim + 1..])` work items.
/// `dst_shape[dim]` and `ids_shape[0]` are forwarded as the destination and
/// index lengths.
///
/// # Panics
/// Panics if `dim` is out of range for `src_shape`/`dst_shape`, or if
/// `ids_shape` is empty.
#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    src_shape: &[usize],
    dst_shape: &[usize],
    ids_shape: &[usize],
    dim: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
    let src_dim_size = src_shape[dim];
    let dst_el = left_size * right_size;
    let dst_dim_size = dst_shape[dim];
    let ids_dim_size = ids_shape[0];
    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            dst_el,
            left_size,
            src_dim_size,
            right_size,
            dst_dim_size,
            ids_dim_size,
            &input,
            &ids,
            output
        )
    );
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

/// A single Metal function-constant value, used to specialise a pipeline
/// through [`ConstantValues`].
#[derive(Debug, PartialEq)]
pub enum Value {
    USize(usize),
    Bool(bool),
    F32(f32),
    U16(u16),
}

// `f32` is not `Hash`, so `F32` hashes its bit pattern.
// NOTE: `0.0` and `-0.0` compare equal but hash differently, so `Hash` and
// `PartialEq` are not fully consistent for `F32`.
impl std::hash::Hash for Value {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            Value::F32(v) => v.to_bits().hash(state),
            Value::USize(v) => v.hash(state),
            Value::U16(v) => v.hash(state),
            Value::Bool(v) => v.hash(state),
        }
    }
}

impl Value {
    /// Metal data type used when this value is set on a `FunctionConstantValues`.
    fn data_type(&self) -> MTLDataType {
        match self {
            // NOTE(review): declared as a 32-bit `UInt`, while
            // `function_constant_values` passes a pointer to a 64-bit `usize`.
            // On little-endian targets this presumably yields the low 32 bits.
            Value::USize(_) => MTLDataType::UInt,
            Value::F32(_) => MTLDataType::Float,
            Value::U16(_) => MTLDataType::UShort,
            Value::Bool(_) => MTLDataType::Bool,
        }
    }
}

/// `Eq` does not strictly hold for `F32` (NaN != NaN). This is accepted so that
/// `ConstantValues` can derive `Eq` and `Hash`.
impl Eq for Value {}

/// An ordered list of `(index, value)` function-constant assignments.
#[derive(Debug, Eq, PartialEq, Hash)]
struct ConstantValues(Vec<(usize, Value)>);

impl ConstantValues {
    pub fn new(values: Vec<(usize, Value)>) -> Self {
        Self(values)
    }

    /// Builds a Metal `FunctionConstantValues` with each value set at its index,
    /// using the type from [`Value::data_type`].
    ///
    /// The pointers handed to Metal point into `self`.
    /// NOTE(review): this assumes Metal copies the value during
    /// `set_constant_value_at_index`, or that `self` outlives the returned object's
    /// use. Confirm.
    fn function_constant_values(&self) -> FunctionConstantValues {
        let f = FunctionConstantValues::new();
        for (index, value) in &self.0 {
            let ty = value.data_type();
            match value {
                Value::USize(v) => {
                    f.set_constant_value_at_index(
                        v as *const usize as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::F32(v) => {
                    f.set_constant_value_at_index(
                        v as *const f32 as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::U16(v) => {
                    f.set_constant_value_at_index(
                        v as *const u16 as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
                Value::Bool(v) => {
                    f.set_constant_value_at_index(
                        v as *const bool as *const c_void,
                        ty,
                        *index as u64,
                    );
                }
            }
        }
        f
    }
}

/// Element type accepted by the SDPA (scaled dot-product attention) kernels.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum SdpaDType {
    BF16,
    F16,
    F32,
}

/// SDPA full is supported when:
/// - q head dim == 32, 64, 96, 128, 256
/// - no mask
/// - q heads == kv heads
/// - final type != bf16 (TODO maybe just template this kernel too?)
/// - q,k,v are contiguous #[allow(clippy::too_many_arguments)] pub fn call_sdpa_full( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, q_offset: usize, q_shape: &[usize], q_buffer: &Buffer, k_offset: usize, k_buffer: &Buffer, v_offset: usize, v_buffer: &Buffer, output: &Buffer, alpha: f32, softcapping: f32, itype: SdpaDType, ) -> Result<(), MetalKernelError> { #[derive(Debug)] #[repr(C)] struct MLXFastAttentionParams { m: i32, n: i32, k: i32, ldq: i32, // ldq == ldo ldk: i32, ldv: i32, lds: i32, ldo: i32, tiles_n: i32, tiles_m: i32, batch_stride_q: i32, batch_stride_k: i32, batch_stride_v: i32, batch_stride_o: i32, swizzle_log: i32, gemm_n_iterations_aligned: i32, gemm_k_iterations_aligned: i32, gemm_sv_m_block_iterations: i32, batch_ndim: i32, alpha: f32, softcapping: f32, } let bk = q_shape.last().unwrap(); const BN: usize = 16; const BM: usize = 16; const WM: usize = 2; const WN: usize = 2; let name = match (bk, itype) { (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half", (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half", (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half", (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half", (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half", (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float", (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float", (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float", (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float", (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float", (other, SdpaDType::F16 | SdpaDType::F32) => { return Err(MetalKernelError::SdpaHeadSizeMismatch { variation: "full", got: *other, expected: vec![32, 64, 96, 128, 256], }) } (_, SdpaDType::BF16) => { return Err(MetalKernelError::SdpaHeadDTypeMismatch { variation: 
"full", got: SdpaDType::BF16, }) } }; let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); // q = (bs, qhead, seq, hidden) // k/v = (bs, kv_head, seq, hidden) let qseq = q_shape[q_shape.len() - 2]; let m = q_shape[q_shape.len() - 2]; let n = m; let k = q_shape[q_shape.len() - 1]; let bs_out = q_shape[0] * q_shape[1]; let batch_shape = [q_shape[0] * q_shape[1]]; let dk = q_shape[q_shape.len() - 1]; let ldq = dk; let ldk = dk; let ldv = dk; let lds = BN; let ldo = dk; let tn = 1; let tm = m.div_ceil(BM); let b_stride_q = dk * qseq; let b_stride_k = dk * qseq; let b_stride_v = dk * qseq; let b_stride_o = dk * qseq; let swizzle_log = 0; let gemm_n_iterations_aligned = n.div_ceil(BN); let gemm_k_iterations_aligned = k.div_ceil(*bk); let gemm_sv_m_block_iterations = m.div_ceil(BM); let batch_ndim = batch_shape.len(); let alpha = if softcapping != 1. 
{ alpha / softcapping } else { alpha }; let params = MLXFastAttentionParams { m: m as i32, n: n as i32, k: k as i32, ldq: ldq as i32, ldk: ldk as i32, ldv: ldv as i32, lds: lds as i32, ldo: ldo as i32, tiles_n: tn, tiles_m: tm as i32, batch_stride_q: b_stride_q as i32, batch_stride_k: b_stride_k as i32, batch_stride_v: b_stride_v as i32, batch_stride_o: b_stride_o as i32, swizzle_log, gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32, gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32, gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32, batch_ndim: batch_ndim as i32, alpha, softcapping, }; let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o]; impl EncoderParam for MLXFastAttentionParams { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, core::mem::size_of::() as u64, &data as *const MLXFastAttentionParams as *const c_void, ); } } set_params!( encoder, ( (q_buffer, q_offset), (k_buffer, k_offset), (v_buffer, v_offset), output, params, &batch_shape[..], &batch_strides[..] 
) ); let grid_dims = MTLSize { width: 1, height: tm as u64, depth: bs_out as u64, }; let group_dims = MTLSize { width: 32, height: WM as u64, depth: WN as u64, }; encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_dims, group_dims); Ok(()) } /// SDPA full is supported when: /// - q head dim == 64, 96, 128 /// - no mask /// - q,k,v are contiguous #[allow(clippy::too_many_arguments)] pub fn call_sdpa_vector( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, q_offset: usize, q_shape: &[usize], q_buffer: &Buffer, k_offset: usize, k_shape: &[usize], k_stride: &[usize], k_buffer: &Buffer, v_offset: usize, v_stride: &[usize], v_buffer: &Buffer, output: &Buffer, alpha: f32, softcapping: f32, itype: SdpaDType, ) -> Result<(), MetalKernelError> { let bk = q_shape.last().unwrap(); let gqa_factor = (q_shape[1] / k_shape[1]) as i32; let n = k_shape[2] as i32; let b = (q_shape[0] * q_shape[1]) as i32; let kstride = k_stride[1]; let vstride = v_stride[1]; let name = match (bk, itype) { (32, SdpaDType::F16) => "sdpa_vector_float16_t_32", (64, SdpaDType::F16) => "sdpa_vector_float16_t_64", (96, SdpaDType::F16) => "sdpa_vector_float16_t_96", (128, SdpaDType::F16) => "sdpa_vector_float16_t_128", (256, SdpaDType::F16) => "sdpa_vector_float16_t_256", (32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32", (64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64", (96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96", (128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128", (256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256", (32, SdpaDType::F32) => "sdpa_vector_float_32", (64, SdpaDType::F32) => "sdpa_vector_float_64", (96, SdpaDType::F32) => "sdpa_vector_float_96", (128, SdpaDType::F32) => "sdpa_vector_float_128", (256, SdpaDType::F32) => 
"sdpa_vector_float_256", (other, _) => { return Err(MetalKernelError::SdpaHeadSizeMismatch { variation: "vector", got: *other, expected: vec![32, 64, 96, 128, 256], }) } }; let alpha = if softcapping != 1. { alpha / softcapping } else { alpha }; let constants = Some(ConstantValues::new(vec![( 20, Value::Bool(/* sdpa_vector_has_mask */ false), )])); let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); // q = (bs, qhead, seq, hidden) // k/v = (bs, kv_head, kv_seq, hidden) set_params!( encoder, ( (q_buffer, q_offset), (k_buffer, k_offset), (v_buffer, v_offset), output, gqa_factor, n, kstride, vstride, alpha, softcapping ) ); let grid_dims = MTLSize { width: 1, height: b as u64, depth: 1_u64, }; let group_dims = MTLSize { width: 1024, height: 1, depth: 1, }; encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_dims, group_dims); Ok(()) } pub const SDPA_2PASS_BLOCKS: usize = 32; /// SDPA vector 2pass is supported when: /// - q head dim == 64, 96, 128 /// - no mask /// - q,k,v are contiguous #[allow(clippy::too_many_arguments)] pub fn call_sdpa_vector_2pass( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, q_offset: usize, q_shape: &[usize], q_buffer: &Buffer, k_offset: usize, k_shape: &[usize], k_stride: &[usize], k_buffer: &Buffer, v_offset: usize, v_stride: &[usize], v_buffer: &Buffer, output: &Buffer, intermediate: &Buffer, sums: &Buffer, maxs: &Buffer, alpha: f32, softcapping: f32, itype: SdpaDType, ) -> Result<(), MetalKernelError> { let bk = q_shape.last().unwrap(); // First pass { let name_pass1 = match (bk, itype) { (32, SdpaDType::F16) => 
"sdpa_vector_2pass_1_float16_t_32", (64, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_64", (96, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_96", (128, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_128", (256, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_256", (32, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_32", (64, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_64", (96, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_96", (128, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_128", (256, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_256", (32, SdpaDType::F32) => "sdpa_vector_2pass_1_float_32", (64, SdpaDType::F32) => "sdpa_vector_2pass_1_float_64", (96, SdpaDType::F32) => "sdpa_vector_2pass_1_float_96", (128, SdpaDType::F32) => "sdpa_vector_2pass_1_float_128", (256, SdpaDType::F32) => "sdpa_vector_2pass_1_float_256", (other, _) => { return Err(MetalKernelError::SdpaHeadSizeMismatch { variation: "vector_2pass_1", got: *other, expected: vec![32, 64, 96, 128, 256], }) } }; let gqa_factor = (q_shape[1] / k_shape[1]) as i32; let n = k_shape[2] as i32; let b = (q_shape[0] * q_shape[1]) as i32; let kstride = k_stride[1]; let vstride = v_stride[1]; let alpha = if softcapping != 1. 
{ alpha / softcapping } else { alpha }; let constants = Some(ConstantValues::new(vec![( 20, Value::Bool(/* sdpa_vector_has_mask */ false), )])); let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); // q = (bs, qhead, seq, hidden) // k/v = (bs, kv_head, kv_seq, hidden) set_params!( encoder, ( (q_buffer, q_offset), (k_buffer, k_offset), (v_buffer, v_offset), intermediate, sums, maxs, gqa_factor, n, kstride, vstride, alpha, softcapping ) ); let grid_dims = MTLSize { width: 1, height: b as u64, depth: SDPA_2PASS_BLOCKS as u64, }; let group_dims = MTLSize { width: 8 * 32, height: 1, depth: 1, }; encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(intermediate, metal::MTLResourceUsage::Write); encoder.use_resource(sums, metal::MTLResourceUsage::Write); encoder.use_resource(maxs, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_dims, group_dims); } // Final pass { let name_pass2 = match (bk, itype) { (32, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_32", (64, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_64", (96, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_96", (128, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_128", (256, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_256", (32, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_32", (64, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_64", (96, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_96", (128, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_128", (256, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_256", (32, SdpaDType::F32) => "sdpa_vector_2pass_2_float_32", (64, SdpaDType::F32) => "sdpa_vector_2pass_2_float_64", (96, 
SdpaDType::F32) => "sdpa_vector_2pass_2_float_96", (128, SdpaDType::F32) => "sdpa_vector_2pass_2_float_128", (256, SdpaDType::F32) => "sdpa_vector_2pass_2_float_256", (other, _) => { return Err(MetalKernelError::SdpaHeadSizeMismatch { variation: "vector_2pass_2", got: *other, expected: vec![32, 64, 96, 128, 256], }) } }; let b = (q_shape[0] * q_shape[1]) as i32; let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); // q = (bs, qhead, seq, hidden) // k/v = (bs, kv_head, kv_seq, hidden) set_params!(encoder, (intermediate, sums, maxs, output)); let grid_dims = MTLSize { width: 1, height: b as u64, depth: 1, }; let group_dims = MTLSize { width: 1024, height: 1, depth: 1, }; encoder.use_resource(intermediate, metal::MTLResourceUsage::Write); encoder.use_resource(sums, metal::MTLResourceUsage::Write); encoder.use_resource(maxs, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_dims, group_dims); } Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_im2col1d_strided( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], strides: &[usize], (k_size, stride, padding, dilation): (usize, usize, usize, usize), input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1; let dst_el = shape[0] * l_out * shape[1] * k_size; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output) ); 
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

/// Encodes a 1D col2im kernel from the conv library.
///
/// `shape[1]` is used as the input length and `shape[2]` as the output channel
/// count. The output length is `(l_in - 1) * stride + k_size`, and
/// `shape[0] * c_out * l_out` work items are dispatched.
#[allow(clippy::too_many_arguments)]
pub fn call_col2im1d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    k_size: usize,
    stride: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
    let in_len = shape[1];
    let out_channels = shape[2];
    let out_len = (in_len - 1) * stride + k_size;
    let total = shape[0] * out_channels * out_len;
    let (groups, threads_per_group) = linear_split(&pipeline, total);

    let encoder_handle = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder_handle.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (total, out_len, in_len, out_channels, k_size, stride, &input, output)
    );

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}

/// Encodes a strided 2D im2col kernel from the conv library.
///
/// `shape[2]` and `shape[3]` are used as the height and width. Each output
/// extent is `(len + 2 * padding - dilation * (k - 1) - 1) / stride + 1`, and
/// `shape[0] * h_out * w_out * shape[1] * h_k * w_k` work items are dispatched.
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    strides: &[usize],
    (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
    // Output extent of a padded, dilated, strided window along one axis.
    let window_out =
        |len: usize, k: usize| (len + 2 * padding - dilation * (k - 1) - 1) / stride + 1;
    let (height, width) = (shape[2], shape[3]);
    let h_out = window_out(height, h_k);
    let w_out = window_out(width, w_k);
    let total = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
    let (groups, threads_per_group) = linear_split(&pipeline, total);

    let encoder_handle = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder_handle.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            total,
            h_out,
            w_out,
            h_k,
            w_k,
            stride,
            padding,
            dilation,
            shape,
            strides,
            &input,
            output
        )
    );

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}

/// Encodes a 2D nearest-neighbour upsampling kernel from the conv library.
///
/// `out_w * out_h * shape[0] * shape[1]` work items are dispatched.
/// `scale_w` is derived from `shape[2]` and `scale_h` from `shape[3]`.
/// NOTE(review): confirm this axis order matches what the kernel expects.
#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_2d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_w: usize,
    out_h: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
    let total = out_w * out_h * shape[0] * shape[1];
    let w_scale = shape[2] as f32 / out_w as f32;
    let h_scale = shape[3] as f32 / out_h as f32;
    let (groups, threads_per_group) = linear_split(&pipeline, total);

    let encoder_handle = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder_handle.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (out_w, out_h, w_scale, h_scale, shape, strides, &input, output)
    );

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}

/// Encodes a random-library uniform kernel that fills `buffer` with `length`
/// values, parameterised by `min`/`max`, reading and updating `seed`.
///
/// # Errors
/// Returns an error (as `MetalKernelError::LoadLibraryError`) if `min >= max`.
#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    min: f32,
    max: f32,
    length: usize,
    seed: &Buffer,
    buffer: &Buffer,
) -> Result<(), MetalKernelError> {
    if min >= max {
        return Err(MetalKernelError::LoadLibraryError(
            "min must be less than max".to_string(),
        ));
    }
    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

    let odd = (length % 2 != 0) as usize;
    let (thread_group_count, thread_group_size) =
linear_split(&pipeline, length / 2 + odd); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, min, max, seed, buffer)); encoder.use_resource( seed, metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write, ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_random_normal( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, mean: f32, stddev: f32, length: usize, seed: &Buffer, buffer: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Random, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let odd = (length % 2 != 0) as usize; let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, mean, stddev, seed, buffer)); encoder.use_resource( seed, metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write, ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) } #[derive(Debug, Clone, Copy)] pub enum GgmlDType { Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, Q2K, Q3K, Q4K, Q5K, Q6K, Q8K, F16, F32, } #[allow(clippy::too_many_arguments)] pub fn call_quantized_matmul_mv_t( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, dtype: GgmlDType, (b, m, n, k): (usize, usize, usize, usize), lhs: &Buffer, lhs_offset: usize, rhs: &Buffer, dst_offset: usize, dst: &Buffer, ) -> Result<(), MetalKernelError> { // Everything is in reverse let ne00 = k as i64; let ne01 = n as i64; let ne02 = b as i64; let ne03 = 1i64; let nb00 = 0i64; let nb01 = 0i64; let nb02 = 0i64; let ne10 = k as i64; let ne11 = m as i64; let ne12 = b as i64; let ne13 = 1i64; let nb10 = 0i64; let nb11 = 0i64; let nb12 = 0i64; let 
ne0 = n as i64; let ne1 = m as i64; let r2: u32 = (ne12 / ne02) as u32; let r3: u32 = (ne13 / ne03) as u32; let (nth0, nth1, align) = match dtype { GgmlDType::Q4_0 | GgmlDType::Q4_1 | GgmlDType::Q5_0 | GgmlDType::Q5_1 | GgmlDType::Q8_0 | GgmlDType::Q8_1 => { let nth0 = 8; let nth1 = 8; let align = 8; (nth0, nth1, align) } GgmlDType::Q2K => { // Fixing a bug in Metal for GGML // https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576 let nth0 = 2; let nth1 = 32; let align = 4; (nth0, nth1, align) } GgmlDType::Q4K => { let nth0 = 4; let nth1 = 8; let align = 4; (nth0, nth1, align) } GgmlDType::Q3K | GgmlDType::Q5K => { let nth0 = 2; let nth1 = 32; let align = 4; (nth0, nth1, align) } GgmlDType::Q6K => { let nth0 = 2; let nth1 = 32; let align = 2; (nth0, nth1, align) } GgmlDType::F16 | GgmlDType::Q8K => { // Original implem uses rows let nth0 = 32; let nth1 = 1; let align = 8; (nth0, nth1, align) } GgmlDType::F32 => { let nth0 = 32; let nth1 = 1; let align = 8; (nth0, nth1, align) } }; let thread_groups_count = MTLSize { width: divide(ne01 as usize, align), height: ne11 as u64, depth: (ne12 * ne13) as u64, }; let threads_per_threadgroup = MTLSize { width: nth0, height: nth1, depth: 1, }; let name = match dtype { GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32", GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32", GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32", GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32", GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32", GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32", GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32", GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32", GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32", GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32", GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32", GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", GgmlDType::F16 => "kernel_mul_mv_f16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; let encoder = ep.encoder(); 
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            rhs,
            (lhs, lhs_offset),
            (dst, dst_offset),
            ne00,
            ne01,
            ne02,
            nb00,
            nb01,
            nb02,
            ne10,
            ne11,
            ne12,
            nb10,
            nb11,
            nb12,
            ne0,
            ne1,
            r2,
            r3
        )
    );
    encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
    encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
    encoder.use_resource(dst, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
    Ok(())
}

/// Ceiling division of `m` by `b`, returned as Metal's `NSUInteger`.
///
/// Used when sizing threadgroup counts, so that a trailing partial group of
/// fewer than `b` items is still dispatched.
///
/// # Panics
///
/// Panics if `b` is zero (via [`usize::div_ceil`]).
fn divide(m: usize, b: usize) -> NSUInteger {
    m.div_ceil(b) as NSUInteger
}

/// Encodes the 2D pooling kernel `name` from the [`Source::Conv`] library.
///
/// The dispatch is sized with `linear_split` over
/// `out_w * out_h * shape[0] * shape[1]` output elements.
///
/// * `shape` / `strides` - forwarded to the kernel unchanged. The input is
///   presumably described by these (TODO confirm against the Metal kernel).
///   At least two entries are required because `shape[0]` and `shape[1]` are
///   read here.
/// * `w_k`, `h_k`, `w_stride`, `h_stride` - forwarded to the kernel. These are
///   presumably the pooling window size and stride along width/height.
/// * `input` is registered as read-only and `output` as write-only.
///
/// # Errors
///
/// Returns an error if the compute pipeline for `name` cannot be loaded.
///
/// # Panics
///
/// Panics if `shape.len() < 2`.
// The launcher takes every kernel parameter individually, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn call_pool2d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    strides: &[usize],
    out_w: usize,
    out_h: usize,
    w_k: usize,
    h_k: usize,
    w_stride: usize,
    h_stride: usize,
    input: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let dst_el = out_w * out_h * shape[0] * shape[1];
    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    // NOTE: `set_params!` presumably binds these positionally; keep the order in
    // sync with the Metal kernel's argument list.
    set_params!(
        encoder,
        (w_k, h_k, w_stride, h_stride, shape, strides, input, output)
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

/// Encodes the 1D transposed-convolution kernel `name` from the
/// [`Source::Conv`] library.
///
/// The dispatch is sized with `linear_split` over `c_out * l_out * b_size`
/// output elements. `c_out` and `b_size` are used only for that sizing and
/// are not passed to the kernel.
///
/// * `l_out`, `stride`, `padding`, `out_padding`, `dilation`, `src_shape`,
///   `src_strides`, `kernel_shape` and `kernel_strides` are forwarded to the
///   kernel unchanged.
/// * `input` and `kernel` are bound together with `input_offset` /
///   `kernel_offset` and registered as read-only.
/// * `output` is bound without an offset and registered as write-only.
///
/// # Errors
///
/// Returns an error if the compute pipeline for `name` cannot be loaded.
// The launcher takes every kernel parameter individually, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose1d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    dilation: usize,
    stride: usize,
    padding: usize,
    out_padding: usize,
    c_out: usize,
    l_out: usize,
    b_size: usize,
    src_shape: &[usize],
    src_strides: &[usize],
    kernel_shape: &[usize],
    kernel_strides: &[usize],
    input: &Buffer,
    input_offset: usize,
    kernel: &Buffer,
    kernel_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let dst_el = c_out * l_out * b_size;
    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    // NOTE: `set_params!` presumably binds these positionally; keep the order in
    // sync with the Metal kernel's argument list.
    set_params!(
        encoder,
        (
            l_out,
            stride,
            padding,
            out_padding,
            dilation,
            src_shape,
            src_strides,
            kernel_shape,
            kernel_strides,
            (input, input_offset),
            (kernel, kernel_offset),
            output
        )
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

/// Launch parameters for [`call_conv_transpose2d`].
///
/// `c_out`, `out_w`, `out_h` and `b_size` size the dispatch.
/// Every other field, plus `out_w` and `out_h`, is forwarded to the kernel.
#[derive(Debug, Clone, Copy)]
pub struct CallConvTranspose2dCfg<'a> {
    /// Dilation, forwarded to the kernel.
    pub dilation: usize,
    /// Stride, forwarded to the kernel.
    pub stride: usize,
    /// Padding, forwarded to the kernel.
    pub padding: usize,
    /// Output padding, forwarded to the kernel.
    pub output_padding: usize,
    /// Presumably the output channel count; used only to size the dispatch.
    pub c_out: usize,
    /// Output width; sizes the dispatch and is forwarded to the kernel.
    pub out_w: usize,
    /// Output height; sizes the dispatch and is forwarded to the kernel.
    pub out_h: usize,
    /// Presumably the batch size; used only to size the dispatch.
    pub b_size: usize,
    /// Input dimensions, forwarded to the kernel.
    pub input_dims: &'a [usize],
    /// Input strides, forwarded to the kernel.
    pub input_stride: &'a [usize],
    /// Kernel (weight) dimensions, forwarded to the kernel.
    pub kernel_dims: &'a [usize],
    /// Kernel (weight) strides, forwarded to the kernel.
    pub kernel_stride: &'a [usize],
    /// Offset bound together with the `input` buffer. The unit is whatever the
    /// `(buffer, usize)` encoder parameter expects; presumably bytes, confirm.
    pub input_offset: usize,
    /// Offset bound together with the `kernel` buffer (same unit as `input_offset`).
    pub kernel_offset: usize,
}

/// Encodes the 2D transposed-convolution kernel `name` from the
/// [`Source::Conv`] library, using the launch parameters in `cfg`.
///
/// The dispatch is sized with `linear_split` over
/// `cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size` output elements.
/// `input` and `kernel` are bound with their offsets from `cfg` and registered
/// as read-only. `output` is registered as write-only.
///
/// # Errors
///
/// Returns an error if the compute pipeline for `name` cannot be loaded.
// Eight parameters still exceed clippy's default limit of seven, even with `cfg` bundled.
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose2d(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    cfg: CallConvTranspose2dCfg,
    input: &Buffer,
    kernel: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
    let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    // NOTE: `set_params!` presumably binds these positionally; keep the order in
    // sync with the Metal kernel's argument list.
    set_params!(
        encoder,
        (
            cfg.out_w,
            cfg.out_h,
            cfg.stride,
            cfg.padding,
            cfg.output_padding,
            cfg.dilation,
            cfg.input_dims,
            cfg.input_stride,
            cfg.kernel_dims,
            cfg.kernel_stride,
            (input, cfg.input_offset),
            (kernel, cfg.kernel_offset),
            output
        )
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

/// Encodes the [`Source::Fill`] kernel `name`, passing it `output`, the value
/// `v` and `length`.
///
/// The kernel presumably writes `v` into the first `length` elements of
/// `output` (TODO confirm against fill.metal). The dispatch is sized with
/// `linear_split` over `length` elements, and `output` is registered as
/// write-only.
///
/// # Errors
///
/// Returns an error if the compute pipeline for `name` cannot be loaded.
pub fn call_const_fill(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    length: usize,
    output: &Buffer,
    v: impl EncoderParam,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (output, v, length));
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[cfg(test)]
mod tests;