diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a0b852a4..d298483b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, - Device, Function, Library, MTLSize, + Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; @@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal"); const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { let size = length as u64; @@ -105,6 +106,7 @@ pub enum Source { Ternary, Cast, Reduce, + Mfa, } macro_rules! ops{ @@ -179,9 +181,88 @@ impl From> for MetalKernelError { } } -type KernelMap = HashMap<&'static str, T>; +#[derive(Debug, PartialEq)] +pub enum Value { + U32(u32), + Bool(bool), + F32(f32), + U16(u16), +} + +impl std::hash::Hash for Value { + fn hash(&self, state: &mut H) { + match self { + Value::F32(v) => v.to_bits().hash(state), + Value::U32(v) => v.hash(state), + Value::U16(v) => v.hash(state), + Value::Bool(v) => v.hash(state), + } + } +} + +impl Value { + fn data_type(&self) -> MTLDataType { + match self { + Value::U32(_) => MTLDataType::UInt, + Value::F32(_) => MTLDataType::Float, + Value::U16(_) => MTLDataType::UShort, + Value::Bool(_) => MTLDataType::Bool, + } + } +} + +/// Not true, good enough for our purposes. +impl Eq for Value {} + +#[derive(Debug, Eq, PartialEq, Hash)] +struct ConstantValues(Vec<(usize, Value)>); + +impl ConstantValues { + pub fn new(values: Vec<(usize, Value)>) -> Self { + Self(values) + } + + fn function_constant_values(&self) -> FunctionConstantValues { + let f = FunctionConstantValues::new(); + for (index, value) in &self.0 { + let ty = value.data_type(); + match value { + Value::U32(v) => { + f.set_constant_value_at_index( + v as *const u32 as *const c_void, + ty, + *index as u64, + ); + } + Value::F32(v) => { + f.set_constant_value_at_index( + v as *const f32 as *const c_void, + ty, + *index as u64, + ); + } + Value::U16(v) => { + f.set_constant_value_at_index( + v as *const u16 as *const c_void, + ty, + *index as u64, + ); + } + Value::Bool(v) => { + f.set_constant_value_at_index( + v as *const bool as *const c_void, + ty, + *index as u64, + ); + } + } + } + f + } +} + type Libraries = HashMap; -type Pipelines = KernelMap; +type Pipelines = HashMap<(&'static str, Option), ComputePipelineState>; #[derive(Debug, Default)] pub struct Kernels { @@ -208,6 +289,7 @@ impl Kernels { Source::Indexing => INDEXING, Source::Cast => CAST, Source::Reduce => REDUCE, + Source::Mfa => unimplemented!("Mfa is not a source"), } } @@ -220,10 +302,20 @@ impl Kernels { if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { - let source_content = self.get_library_source(source); - let lib = device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; + let lib = match source { + Source::Mfa => { + let source_data = MFA; + device + .new_library_with_data(source_data) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + } + source => { + let source_content = self.get_library_source(source); + device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + } + }; libraries.insert(source, lib.clone()); Ok(lib) } @@ -234,19 +326,41 @@ impl Kernels { device: &Device, source: Source, name: &'static str, + constants: Option, ) -> Result { let func = self .load_library(device, source)? - .get_function(name, None) + .get_function(name, constants) .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; Ok(func) - // let mut funcs = self.funcs.write()?; - // if let Some(func) = funcs.get(name) { - // Ok(func.clone()) - // } else { - // funcs.insert(name, func.clone()); - // Ok(func) - // } + } + + fn load_pipeline_with_constants( + &self, + device: &Device, + source: Source, + name: &'static str, + constants: Option, + ) -> Result { + let mut pipelines = self.pipelines.write()?; + let key = (name, constants); + if let Some(pipeline) = pipelines.get(&key) { + Ok(pipeline.clone()) + } else { + let (name, constants) = key; + let func = self.load_function( + device, + source, + name, + constants.as_ref().map(|c| c.function_constant_values()), + )?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert((name, constants), pipeline.clone()); + + Ok(pipeline) + } } pub fn load_pipeline( @@ -255,18 +369,7 @@ impl Kernels { source: Source, name: &'static str, ) -> Result { - let mut pipelines = self.pipelines.write()?; - if let Some(pipeline) = pipelines.get(name) { - Ok(pipeline.clone()) - } else { - let func = self.load_function(device, source, name)?; - let pipeline = device - .new_compute_pipeline_state_with_function(&func) - .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; - pipelines.insert(name, pipeline.clone()); - - Ok(pipeline) - } + self.load_pipeline_with_constants(device, source, name, None) } } @@ -706,5 +809,130 @@ pub fn call_index_select( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_gemm( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + lhs_offset: usize, + lhs_buffer: &Buffer, + rhs_stride: &[usize], + rhs_offset: usize, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let a_trans = false; + let b_trans = false; + let d_trans = false; + let alpha = 1.0; + let beta = 0.0; + let batched = b > 1; + let fused_activation = false; + let fused_bias = false; + let m_simd = 16; + let n_simd = 16; + let k_simd = 16; + let m_splits = 2; + let n_splits = 2; + let constants = Some(ConstantValues::new(vec![ + (0, Value::U32(m as u32)), + (1, Value::U32(n as u32)), + (2, Value::U32(k as u32)), + (10, Value::Bool(a_trans)), + (11, Value::Bool(b_trans)), + (13, Value::Bool(d_trans)), + (20, Value::F32(alpha)), + (21, Value::F32(beta)), + (100, Value::Bool(batched)), + (101, Value::Bool(fused_activation)), + (200, Value::U16(m_simd)), + (201, Value::U16(n_simd)), + (202, Value::U16(k_simd)), + (210, Value::U16(m_splits)), + (211, Value::U16(n_splits)), + (50_001, Value::Bool(fused_bias)), + ])); + let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; + let m_group = m_simd * m_splits; + let n_group = n_simd * n_splits; + + let a_block_length = m_group * k_simd; + let b_block_length = k_simd * n_group; + + let mut block_elements = a_block_length + b_block_length; + if (m % 8 != 0) && (n % 8 != 0) { + let c_block_length = m_group * n_group; + block_elements = std::cmp::max(c_block_length, block_elements) + } + if fused_bias { + if d_trans { + block_elements = std::cmp::max(block_elements, m_group); + } else { + block_elements = std::cmp::max(block_elements, n_group); + } + } + // TODO adapt for f16 + let bytes = match name { + "sgemm" => 4, + "hgemm" => 2, + other => { + return Err(MetalKernelError::LoadLibraryError(format!( + "{other} is not a valid kernel for gemm" + ))); + } + }; + let block_bytes = block_elements * bytes; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(block_bytes.into(), 0); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(2, Some(output), 0); + // TODO Tensor D + + let grid_z = b; + let byte_stride_a: usize = *lhs_stride.get(lhs_stride.len() - 2).unwrap_or(&0); + let byte_stride_b = *rhs_stride.get(rhs_stride.len() - 2).unwrap_or(&0); + let byte_stride_c = m * n; + // TODO byte_stride_d + let byte_stride_d = 1; + + let mut buffer = Vec::with_capacity(b * 4); + for i in 0..b { + buffer.push(i * byte_stride_a); + buffer.push(i * byte_stride_b); + buffer.push(i * byte_stride_c); + buffer.push(i * byte_stride_d); + } + encoder.set_bytes( + 10, + buffer.len() as NSUInteger, + buffer.as_ptr() as *const NSUInteger as *const c_void, + ); + + let grid_size = MTLSize { + width: divide(n, n_group.into()), + height: divide(m, m_group.into()), + depth: grid_z as NSUInteger, + }; + let group_size = MTLSize { + width: 32 * (m_splits as u64) * (n_splits as u64), + height: 1, + depth: 1, + }; + encoder.dispatch_thread_groups(grid_size, group_size); + encoder.end_encoding(); + + Ok(()) +} + +fn divide(m: usize, b: usize) -> NSUInteger { + ((m + b - 1) / b) as NSUInteger +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib new file mode 100644 index 00000000..8c8ce692 Binary files /dev/null and b/candle-metal-kernels/src/libMetalFlashAttention.metallib differ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 59f54fa9..5805206b 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -725,3 +725,66 @@ fn where_cond() { ); assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } + +fn run_gemm( + (b, m, n, k): (usize, usize, usize, usize), + lhs: &[T], + lhs_stride: Vec, + rhs: &[T], + rhs_stride: Vec, +) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(rhs) as u64, + options, + ); + let length = b * m * n; + let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + call_gemm( + &device, + command_buffer, + &kernels, + "sgemm", + (b, m, n, k), + &lhs_stride, + 0, + &lhs, + &rhs_stride, + 0, + &rhs, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(length) +} + +#[test] +fn gemm() { + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride); + assert_eq!( + approx(results, 4), + vec![ + 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, + 518.0, 548.0, 578.0 + ] + ); +}