diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 12f56d50..7451c90d 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -795,7 +795,6 @@ impl BackendStorage for MetalStorage { rhs_l: &Layout, ) -> Result { // Create descriptors - let (type_id, size) = match self.dtype { DType::F32 => ( metal::mps::MPS_FLOATBIT_ENCODING | 32, diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 186f3209..af17a6d5 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -11,6 +11,7 @@ license = "MIT OR Apache-2.0" [dependencies] metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } +metal-flash-attention = { path = "../../../metal-flash-attention" } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a0b852a4..9324c1a3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,9 +1,7 @@ -use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, - Device, Function, Library, MTLSize, -}; +use metal::{Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger}; use std::collections::HashMap; use std::ffi::c_void; +use std::hash::Hash; use std::sync::RwLock; const AFFINE: &str = include_str!("affine.metal"); @@ -13,6 +11,7 @@ const BINARY: &str = include_str!("binary.metal"); const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +const MFA_LIB: &[u8] = include_bytes!("mfa.metallib"); fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { let size = length as u64; @@ -105,6 +104,7 @@ pub enum Source { Ternary, Cast, Reduce, + MetalFlashAttention, } macro_rules! ops{ @@ -179,7 +179,7 @@ impl From> for MetalKernelError { } } -type KernelMap = HashMap<&'static str, T>; +type KernelMap = HashMap; type Libraries = HashMap; type Pipelines = KernelMap; @@ -189,6 +189,22 @@ pub struct Kernels { pipelines: RwLock, } +enum LibraryDefinition { + Source(&'static str), + Data(&'static [u8]), +} + +impl From<&'static str> for LibraryDefinition { + fn from(s: &'static str) -> Self { + Self::Source(s) + } +} +impl From<&'static [u8]> for LibraryDefinition { + fn from(s: &'static [u8]) -> Self { + Self::Data(s) + } +} + impl Kernels { pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); @@ -199,15 +215,16 @@ impl Kernels { } } - fn get_library_source(&self, source: Source) -> &'static str { + fn get_library_source(&self, source: Source) -> LibraryDefinition { match source { - Source::Affine => AFFINE, - Source::Unary => UNARY, - Source::Binary => BINARY, - Source::Ternary => TERNARY, - Source::Indexing => INDEXING, - Source::Cast => CAST, - Source::Reduce => REDUCE, + Source::Affine => AFFINE.into(), + Source::Unary => UNARY.into(), + Source::Binary => BINARY.into(), + Source::Ternary => TERNARY.into(), + Source::Indexing => INDEXING.into(), + Source::Cast => CAST.into(), + Source::Reduce => REDUCE.into(), + Source::MetalFlashAttention => MFA_LIB.into(), } } @@ -220,10 +237,15 @@ impl Kernels { if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { - let source_content = self.get_library_source(source); - let lib = device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; + let lib = match self.get_library_source(source) { + LibraryDefinition::Source(source_content) => device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?, + LibraryDefinition::Data(data) => device + .new_library_with_data(data) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?, + }; + libraries.insert(source, lib.clone()); Ok(lib) } @@ -233,43 +255,154 @@ impl Kernels { &self, device: &Device, source: Source, - name: &'static str, + key: KernelKey, ) -> Result { let func = self .load_library(device, source)? - .get_function(name, None) + .get_function(key.name, key.constants.map(|c| c.create_function_constant_values())) .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; Ok(func) - // let mut funcs = self.funcs.write()?; - // if let Some(func) = funcs.get(name) { - // Ok(func.clone()) - // } else { - // funcs.insert(name, func.clone()); - // Ok(func) - // } } - pub fn load_pipeline( + pub fn load_pipeline>( &self, device: &Device, source: Source, - name: &'static str, + key: T, ) -> Result { + let key: KernelKey = key.into(); let mut pipelines = self.pipelines.write()?; - if let Some(pipeline) = pipelines.get(name) { + if let Some(pipeline) = pipelines.get(&key) { Ok(pipeline.clone()) } else { - let func = self.load_function(device, source, name)?; + let func = self.load_function(device, source, key.clone())?; let pipeline = device .new_compute_pipeline_state_with_function(&func) .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; - pipelines.insert(name, pipeline.clone()); + pipelines.insert(key, pipeline.clone()); Ok(pipeline) } } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct KernelKey { + name: &'static str, + constants: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum ConstantValueId { + Index(NSUInteger), + Name(&'static str), +} + +trait MetalDType { + const MTL_DATA_TYPE: MTLDataType; +} +macro_rules! metal_dtype { + ($ty:ty, $mtl_data_type:ident) => { + impl MetalDType for $ty { + const MTL_DATA_TYPE: MTLDataType = MTLDataType::$mtl_data_type; + } + } +} +metal_dtype!(f32, Float); +metal_dtype!(u32, UInt); +metal_dtype!(u16, UShort); +metal_dtype!(bool, Bool); + +#[derive(Debug, Clone, PartialEq)] +enum ConstantValue { + Float(f32), + Uint(u32), + UShort(u16), + Bool(bool), +} + +impl Hash for ConstantValue { + fn hash(&self, state: &mut H) { + use ConstantValue::*; + match self { + Float(_) => {}, // do nothing + Uint(v) => v.hash(state), + UShort(v) => v.hash(state), + Bool(v) => v.hash(state), + } + } +} + +impl Eq for ConstantValue {} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ConstantValues(Vec<(ConstantValueId, ConstantValue)>); + +macro_rules! add_indexed_constant { + ($fcv:expr, $value:expr, $ty:ty, $idx:expr) => { + $fcv.set_constant_value_at_index( + $value as *const $ty as *const c_void, + <$ty>::MTL_DATA_TYPE, + $idx, + ) + }; +} +macro_rules! add_named_constant { + ($fcv:expr, $value:expr, $ty:ty, $name:expr) => { + $fcv.set_constant_value_with_name( + $value as *const $ty as *const c_void, + <$ty>::MTL_DATA_TYPE, + $name, + ) + }; +} +impl ConstantValues { + fn create_function_constant_values(&self) -> FunctionConstantValues { + use ConstantValueId::*; + use ConstantValue::*; + let mut function_values = FunctionConstantValues::new(); + + for (id, value) in &self.0 { + match (id, value) { + (Index(index), Float(value)) => { + add_indexed_constant!(function_values, value, f32, *index); + } + (Index(index), Uint(value)) => { + add_indexed_constant!(function_values, value, u32, *index); + } + (Index(index), UShort(value)) => { + add_indexed_constant!(function_values, value, u16, *index); + } + (Index(index), Bool(value)) => { + add_indexed_constant!(function_values, value, bool, *index); + } + (Name(name), Float(value)) => { + add_named_constant!(function_values, value, f32, name); + } + (Name(name), Uint(value)) => { + add_named_constant!(function_values, value, u32, name); + } + (Name(name), UShort(value)) => { + add_named_constant!(function_values, value, u16, name); + } + (Name(name), Bool(value)) => { + add_named_constant!(function_values, value, bool, name); + } + } + } + function_values + } +} + +impl From<&'static str> for KernelKey { + fn from(name: &'static str) -> Self { + Self { + name, + constants: None, + } + } +} + #[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous( device: &Device, @@ -706,5 +839,45 @@ pub fn call_index_select( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_mfa_gemm( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + strides: &[usize], + offset: usize, + output: &Buffer, + output_offset: usize, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::MetalFlashAttention, name)?; + + let num_dims: usize = shape.len(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + set_params!( + encoder, + ( + length, + num_dims, + shape, + strides, + (input, offset), + (output, output_offset) + ) + ); + + let width: usize = shape.iter().product(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/mfa.metallib b/candle-metal-kernels/src/mfa.metallib new file mode 100644 index 00000000..dafd1856 Binary files /dev/null and b/candle-metal-kernels/src/mfa.metallib differ