diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index f745342d..92c486d6 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,9 +4,7 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; -use half::f16; use metal; -use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::path::Path; @@ -115,7 +113,7 @@ impl MetalDevice { pub fn wait_until_completed(&self) { let command_buffers = self.command_buffers.try_write().unwrap(); let index = self.command_buffer_index.try_write().unwrap(); - let n = command_buffers.len(); + // let n = command_buffers.len(); // for i in 0..*index { // let command_buffer = &command_buffers[i]; // println!("Command {i} / {n}: {:?}", command_buffer.status()); @@ -216,39 +214,6 @@ impl MetalDevice { real } - pub fn new_matrix( - &self, - (b, m, n): (NSUInteger, NSUInteger, NSUInteger), - size: NSUInteger, - type_id: u32, - dtype: DType, - ) -> Result<(Matrix, Arc)> { - let elem_count = (b * m * n) as usize; - let buffer = self.new_buffer(elem_count, dtype, "matrix"); - let command_buffer = self.command_buffer(); - command_buffer.set_label("zeros_matmul"); - let blit = command_buffer.new_blit_command_encoder(); - blit.fill_buffer( - &buffer, - metal::NSRange { - location: 0, - length: buffer.length(), - }, - 0, - ); - blit.end_encoding(); - command_buffer.commit(); - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); - - let result_descriptor = - MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id); - let result_matrix = Matrix::init_with_buffer_descriptor(&buffer, 0, &result_descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - Ok((result_matrix, buffer)) - } - pub fn capture>(&self, path: P) -> Result<()> { let capture = metal::CaptureManager::shared(); let descriptor = metal::CaptureDescriptor::new(); @@ -266,22 +231,6 @@ impl MetalDevice { #[derive(Debug, Clone)] pub struct MetalStorage { buffer: Arc, - matrices: Arc< - RwLock< - HashMap< - ( - NSUInteger, - NSUInteger, - NSUInteger, - bool, - NSUInteger, - NSUInteger, - u32, - ), - Matrix, - >, - >, - >, device: MetalDevice, dtype: DType, } @@ -976,7 +925,6 @@ impl BackendStorage for MetalStorage { ) -> Result { crate::bail!("index_add metal") } - fn matmul( &self, rhs: &Self, @@ -985,104 +933,37 @@ impl BackendStorage for MetalStorage { rhs_l: &Layout, ) -> Result { // Create descriptors - let (type_id, size) = match self.dtype { - DType::F32 => ( - metal::mps::MPS_FLOATBIT_ENCODING | 32, - core::mem::size_of::() as NSUInteger, - ), - DType::F16 => ( - metal::mps::MPS_FLOATBIT_ENCODING | 16, - core::mem::size_of::() as NSUInteger, - ), - dtype => todo!("Dtype for matmul {dtype:?} is not supported"), - }; - let lhs_stride = lhs_l.stride(); - let rhs_stride = rhs_l.stride(); - let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; - let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; - let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; - let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - // The a tensor has dims batching, k, n (rhs) - let transpose_left = if lhs_m1 == 1 && lhs_m2 == k { - false - } else if lhs_m1 == m && lhs_m2 == 1 { - true - } else { - Err(MetalError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? + let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul"); + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + dtype => { + return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) + } }; - let transpose_right = if rhs_m1 == 1 && rhs_m2 == n { - false - } else if rhs_m1 == k && rhs_m2 == 1 { - true - } else { - Err(MetalError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? - }; - let b = b as NSUInteger; - let m = m as NSUInteger; - let n = n as NSUInteger; - let k = k as NSUInteger; - - let left_matrix = self.matrix( - (b, m, k), - transpose_left, - size, - lhs_l.start_offset() as NSUInteger * size, - type_id, - )?; - let right_matrix = rhs.matrix( - (b, k, n), - transpose_right, - size, - rhs_l.start_offset() as NSUInteger * size, - type_id, - )?; - let (result_matrix, out_buffer) = - self.device - .new_matrix((b, m, n), size, type_id, self.dtype)?; let command_buffer = self.device.command_buffer(); command_buffer.set_label("matmul"); - - let alpha = 1.0f64; - // let beta = f64::MIN; - let beta = 1.0; - // Create kernel - let matrix_multiplication = MatrixMultiplication::init( - &self.device, - transpose_left, - transpose_right, - m, - n, - k, - alpha, - beta, - ) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - matrix_multiplication.set_batch_size(b); - matrix_multiplication.set_batch_start(0); - - // Encode kernel to command buffer - matrix_multiplication.encode_to_command_buffer( + candle_metal_kernels::call_gemm( + &self.device.device, &command_buffer, - &left_matrix, - &right_matrix, - &result_matrix, - ); + &self.device.kernels, + name, + (b, m, n, k), + &lhs_l.stride(), + lhs_l.start_offset(), + &self.buffer, + &rhs_l.stride(), + rhs_l.start_offset(), + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + // Create kernel command_buffer.commit(); - out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length())); - // println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer)); - Ok(Self::new(out_buffer, self.device.clone(), self.dtype())) + + Ok(Self::new(buffer, self.device.clone(), self.dtype())) } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { @@ -1133,46 +1014,16 @@ impl BackendStorage for MetalStorage { impl MetalStorage { pub fn new(buffer: Arc, device: MetalDevice, dtype: DType) -> Self { - let matrices = Arc::new(RwLock::new(HashMap::new())); Self { buffer, device, dtype, - matrices, } } pub fn buffer(&self) -> &Buffer { &self.buffer } - - fn matrix( - &self, - (b, m, n): (NSUInteger, NSUInteger, NSUInteger), - transpose: bool, - size: NSUInteger, - offset: NSUInteger, - type_id: u32, - ) -> Result { - let key = (b, m, n, transpose, size, offset, type_id); - - // let mut matrices = self.matrices.try_write().unwrap(); - // if let Some(matrix) = matrices.get(&key) { - // Ok(matrix.clone()) - // } else { - let descriptor = if transpose { - MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id) - } else { - MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id) - }; - let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor) - .ok_or_else(|| { - MetalError::from("Failed to create matrix multiplication kernel".to_string()) - })?; - // matrices.insert(key, matrix.clone()); - Ok(matrix) - // } - } } impl BackendDevice for MetalDevice { diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 237bd858..b80dcb79 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, - Device, Function, Library, MTLSize, + Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; @@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal"); const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { let size = length as u64; @@ -105,6 +106,7 @@ pub enum Source { Ternary, Cast, Reduce, + Mfa, } macro_rules! ops{ @@ -179,9 +181,8 @@ impl From> for MetalKernelError { } } -type KernelMap = HashMap<&'static str, T>; type Libraries = HashMap; -type Pipelines = KernelMap; +type Pipelines = HashMap<(&'static str, Option), ComputePipelineState>; #[derive(Debug, Default)] pub struct Kernels { @@ -208,9 +209,9 @@ impl Kernels { Source::Indexing => INDEXING, Source::Cast => CAST, Source::Reduce => REDUCE, + Source::Mfa => panic!("Invalid lib"), } } - pub fn load_library( &self, device: &Device, @@ -220,10 +221,20 @@ impl Kernels { if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { - let source_content = self.get_library_source(source); - let lib = device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; + let lib = match source { + Source::Mfa => { + let source_data = MFA; + device + .new_library_with_data(source_data) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + } + source => { + let source_content = self.get_library_source(source); + device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? + } + }; libraries.insert(source, lib.clone()); Ok(lib) } @@ -234,19 +245,41 @@ impl Kernels { device: &Device, source: Source, name: &'static str, + constants: Option, ) -> Result { let func = self .load_library(device, source)? - .get_function(name, None) + .get_function(name, constants) .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; Ok(func) - // let mut funcs = self.funcs.write()?; - // if let Some(func) = funcs.get(name) { - // Ok(func.clone()) - // } else { - // funcs.insert(name, func.clone()); - // Ok(func) - // } + } + + fn load_pipeline_with_constants( + &self, + device: &Device, + source: Source, + name: &'static str, + constants: Option, + ) -> Result { + let mut pipelines = self.pipelines.write()?; + let key = (name, constants); + if let Some(pipeline) = pipelines.get(&key) { + Ok(pipeline.clone()) + } else { + let (name, constants) = key; + let func = self.load_function( + device, + source, + name, + constants.as_ref().map(|c| c.function_constant_values()), + )?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert((name, constants), pipeline.clone()); + + Ok(pipeline) + } } pub fn load_pipeline( @@ -255,18 +288,7 @@ impl Kernels { source: Source, name: &'static str, ) -> Result { - let mut pipelines = self.pipelines.write()?; - if let Some(pipeline) = pipelines.get(name) { - Ok(pipeline.clone()) - } else { - let func = self.load_function(device, source, name)?; - let pipeline = device - .new_compute_pipeline_state_with_function(&func) - .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; - pipelines.insert(name, pipeline.clone()); - - Ok(pipeline) - } + self.load_pipeline_with_constants(device, source, name, None) } } @@ -830,5 +852,249 @@ pub fn call_index_select( Ok(()) } +#[derive(Debug, PartialEq)] +pub enum Value { + USize(usize), + Bool(bool), + F32(f32), + U16(u16), +} + +impl std::hash::Hash for Value { + fn hash(&self, state: &mut H) { + match self { + Value::F32(v) => v.to_bits().hash(state), + Value::USize(v) => v.hash(state), + Value::U16(v) => v.hash(state), + Value::Bool(v) => v.hash(state), + } + } +} + +impl Value { + fn data_type(&self) -> MTLDataType { + match self { + Value::USize(_) => MTLDataType::UInt, + Value::F32(_) => MTLDataType::Float, + Value::U16(_) => MTLDataType::UShort, + Value::Bool(_) => MTLDataType::Bool, + } + } +} + +/// Not true, good enough for our purposes. +impl Eq for Value {} + +#[derive(Debug, Eq, PartialEq, Hash)] +struct ConstantValues(Vec<(usize, Value)>); + +impl ConstantValues { + pub fn new(values: Vec<(usize, Value)>) -> Self { + Self(values) + } + + fn function_constant_values(&self) -> FunctionConstantValues { + let f = FunctionConstantValues::new(); + for (index, value) in &self.0 { + let ty = value.data_type(); + match value { + Value::USize(v) => { + f.set_constant_value_at_index( + v as *const usize as *const c_void, + ty, + *index as u64, + ); + } + Value::F32(v) => { + f.set_constant_value_at_index( + v as *const f32 as *const c_void, + ty, + *index as u64, + ); + } + Value::U16(v) => { + f.set_constant_value_at_index( + v as *const u16 as *const c_void, + ty, + *index as u64, + ); + } + Value::Bool(v) => { + f.set_constant_value_at_index( + v as *const bool as *const c_void, + ty, + *index as u64, + ); + } + } + } + f + } +} + +#[allow(clippy::too_many_arguments)] +pub fn call_gemm( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + lhs_offset: usize, + lhs_buffer: &Buffer, + rhs_stride: &[usize], + rhs_offset: usize, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + assert!(rhs_stride.len() >= 2); + assert!(lhs_stride.len() >= 2); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + let a_trans = if lhs_m1 == 1 && lhs_m2 == k { + false + } else if lhs_m1 == m && lhs_m2 == 1 { + true + } else { + todo!(); + // Err(MetalError::MatMulNonContiguous { + // lhs_stride: lhs_stride.to_vec(), + // rhs_stride: rhs_stride.to_vec(), + // mnk: (m, n, k), + // })? + }; + let b_trans = if rhs_m1 == 1 && rhs_m2 == n { + false + } else if rhs_m1 == k && rhs_m2 == 1 { + true + } else { + todo!(); + // Err(MetalError::MatMulNonContiguous { + // lhs_stride: lhs_stride.to_vec(), + // rhs_stride: rhs_stride.to_vec(), + // mnk: (m, n, k), + // })? + }; + let d_trans = false; + let alpha = 1.0f32; + let beta = 0.0f32; + let batched = b > 1; + let fused_activation = false; + let fused_bias = false; + let m_simd = 16; + let n_simd = 16; + let k_simd = 16; + let m_splits = 2; + let n_splits = 2; + let constants = Some(ConstantValues::new(vec![ + (0, Value::USize(m)), + (1, Value::USize(n)), + (2, Value::USize(k)), + (10, Value::Bool(a_trans)), + (11, Value::Bool(b_trans)), + (13, Value::Bool(d_trans)), + (20, Value::F32(alpha)), + (21, Value::F32(beta)), + (100, Value::Bool(batched)), + (101, Value::Bool(fused_activation)), + // Garbage + (102, Value::Bool(false)), + (103, Value::Bool(false)), + (113, Value::Bool(false)), + (50_000, Value::Bool(false)), + // End garbage + (200, Value::U16(m_simd)), + (201, Value::U16(n_simd)), + (202, Value::U16(k_simd)), + (210, Value::U16(m_splits)), + (211, Value::U16(n_splits)), + (50_001, Value::Bool(fused_bias)), + ])); + // println!("Constants {constants:?}"); + let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; + let m_group = m_simd * m_splits; + let n_group = n_simd * n_splits; + + let a_block_length = m_group * k_simd; + let b_block_length = k_simd * n_group; + + let mut block_elements = a_block_length + b_block_length; + if (m % 8 != 0) && (n % 8 != 0) { + let c_block_length = m_group * n_group; + block_elements = std::cmp::max(c_block_length, block_elements) + } + if fused_bias { + if d_trans { + block_elements = std::cmp::max(block_elements, m_group); + } else { + block_elements = std::cmp::max(block_elements, n_group); + } + } + // TODO adapt for f16 + let bytes = match name { + "sgemm" => 4, + "hgemm" => 2, + other => { + return Err(MetalKernelError::LoadLibraryError(format!( + "{other} is not a valid kernel for gemm" + ))); + } + }; + let block_bytes = block_elements * bytes; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + // println!("Threadgroup {block_bytes}"); + encoder.set_threadgroup_memory_length(0, block_bytes.into()); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(2, Some(output), 0); + // TODO Tensor D + + let grid_z = b; + if batched { + let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; + let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; + let byte_stride_c = m * n * bytes as usize; + // TODO byte_stride_d + let byte_stride_d = 0; + + let mut buffer: Vec = Vec::with_capacity(b * 4); + for i in 0..b { + buffer.push((i * byte_stride_a) as u64); + buffer.push((i * byte_stride_b) as u64); + buffer.push((i * byte_stride_c) as u64); + buffer.push((i * byte_stride_d) as u64); + } + encoder.set_bytes( + 10, + (buffer.len() * core::mem::size_of::()) as NSUInteger, + buffer.as_ptr() as *const NSUInteger as *const c_void, + ); + } + + let grid_size = MTLSize { + width: divide(n, n_group.into()), + height: divide(m, m_group.into()), + depth: grid_z as NSUInteger, + }; + let group_size = MTLSize { + width: 32 * (m_splits as u64) * (n_splits as u64), + height: 1, + depth: 1, + }; + // println!("grid size {grid_size:?} group size {group_size:?}"); + encoder.dispatch_thread_groups(grid_size, group_size); + encoder.end_encoding(); + + Ok(()) +} + +fn divide(m: usize, b: usize) -> NSUInteger { + ((m + b - 1) / b) as NSUInteger +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib new file mode 100644 index 00000000..f5116ca6 Binary files /dev/null and b/candle-metal-kernels/src/libMetalFlashAttention.metallib differ