From ce0783d9ffe66e8fa187d56e04d08b4224378544 Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Sun, 10 Dec 2023 13:11:53 +0100
Subject: [PATCH] Stash for debugging

---
 candle-core/src/metal_backend.rs |  67 +++----
 candle-metal-kernels/src/lib.rs  | 298 +++++++++++++++++++++++++++----
 2 files changed, 285 insertions(+), 80 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 7451c90d..3f18a290 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -795,14 +795,16 @@ impl BackendStorage for MetalStorage {
         rhs_l: &Layout,
     ) -> Result<Self> {
         // Create descriptors
-        let (type_id, size) = match self.dtype {
+        let (type_id, size, name) = match self.dtype {
             DType::F32 => (
                 metal::mps::MPS_FLOATBIT_ENCODING | 32,
                 core::mem::size_of::<f32>() as NSUInteger,
+                "sgemm",
             ),
             DType::F16 => (
                 metal::mps::MPS_FLOATBIT_ENCODING | 16,
                 core::mem::size_of::<f16>() as NSUInteger,
+                "hgemm",
             ),
             dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
         };
@@ -836,60 +838,37 @@ impl BackendStorage for MetalStorage {
                     mnk: (m, n, k),
                 })?
             };
-        let b = b as NSUInteger;
-        let m = m as NSUInteger;
-        let n = n as NSUInteger;
-        let k = k as NSUInteger;
-        let left_matrix = self.matrix(
-            (b, m, k),
-            transpose_left,
-            size,
-            lhs_l.start_offset() as NSUInteger * size,
-            type_id,
-        )?;
-        let right_matrix = rhs.matrix(
-            (b, k, n),
-            transpose_right,
-            size,
-            rhs_l.start_offset() as NSUInteger * size,
-            type_id,
-        )?;
-        let (result_matrix, out_buffer) =
-            self.device
-                .new_matrix((b, m, n), size, type_id, self.dtype)?;
+        let result_buffer = self.device.new_buffer(b * m * n, self.dtype);
 
         let command_buffer = self.device.command_buffer();
-        let alpha = 1.0f64;
-        let beta = 0.0f64;
-        // Create kernel
-        let matrix_multiplication = MatrixMultiplication::init(
-            &self.device,
+        command_buffer.set_label("mfa gemm");
+
+        candle_metal_kernels::call_mfa_gemm(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            &self.buffer,
+            lhs_l.shape().dims(),
+            &rhs.buffer,
+            rhs_l.shape().dims(),
+            &result_buffer,
+            (b, m, n, k),
             transpose_left,
             transpose_right,
-            m,
-            n,
-            k,
-            alpha,
-            beta,
         )
-        .ok_or_else(|| {
-            MetalError::from("Failed to create matrix multiplication kernel".to_string())
-        })?;
+        .map_err(MetalError::from)?;
 
-        // Encode kernel to command buffer
-        matrix_multiplication.encode_to_command_buffer(
-            &command_buffer,
-            &left_matrix,
-            &right_matrix,
-            &result_matrix,
-        );
-        command_buffer.set_label("matmul");
         drop(command_buffer);
         self.device.commit();
-        Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
+        Ok(Self::new(
+            result_buffer,
+            self.device.clone(),
+            self.dtype(),
+        ))
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 9324c1a3..03633918 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,7 +1,12 @@
-use metal::{Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger};
-use std::collections::HashMap;
+use metal::{
+    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
+    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceUsage, MTLSize,
+    NSUInteger,
+};
+use std::collections::{BTreeMap, HashMap};
 use std::ffi::c_void;
 use std::hash::Hash;
+use std::io::{stdout, Write};
 use std::sync::RwLock;
 
 const AFFINE: &str = include_str!("affine.metal");
@@ -259,7 +264,10 @@ impl Kernels {
     ) -> Result<Function, MetalKernelError> {
         let func = self
             .load_library(device, source)?
-            .get_function(key.name, key.constants.map(|c| c.create_function_constant_values()))
+            .get_function(
+                key.name,
+                key.constants.map(|c| c.create_function_constant_values()),
+            )
             .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
         Ok(func)
     }
@@ -292,7 +300,21 @@ struct KernelKey {
     constants: Option<ConstantValues>,
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+impl KernelKey {
+    fn new(name: &'static str) -> Self {
+        Self {
+            name,
+            constants: None,
+        }
+    }
+
+    fn with_constants(mut self, constants: ConstantValues) -> Self {
+        self.constants = Some(constants);
+        self
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
 enum ConstantValueId {
     Index(NSUInteger),
     Name(&'static str),
@@ -306,7 +328,7 @@ macro_rules! metal_dtype {
         impl MetalDType for $ty {
             const MTL_DATA_TYPE: MTLDataType = MTLDataType::$mtl_data_type;
         }
-    }
+    };
 }
 metal_dtype!(f32, Float);
 metal_dtype!(u32, UInt);
@@ -314,18 +336,18 @@ metal_dtype!(u16, UShort);
 metal_dtype!(bool, Bool);
 
 #[derive(Debug, Clone, PartialEq)]
-enum ConstantValue {
+enum ConstantValueType {
     Float(f32),
     Uint(u32),
     UShort(u16),
     Bool(bool),
 }
 
-impl Hash for ConstantValue {
+impl Hash for ConstantValueType {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        use ConstantValue::*;
+        use ConstantValueType::*;
         match self {
-            Float(_) => {}, // do nothing
+            Float(v) => v.to_bits().hash(state),
             Uint(v) => v.hash(state),
             UShort(v) => v.hash(state),
             Bool(v) => v.hash(state),
@@ -333,10 +355,10 @@ impl Hash for ConstantValue {
     }
 }
 
-impl Eq for ConstantValue {}
+impl Eq for ConstantValueType {}
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-struct ConstantValues(Vec<(ConstantValueId, ConstantValue)>);
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct ConstantValues(BTreeMap<ConstantValueId, ConstantValueType>);
 
 macro_rules! add_indexed_constant {
     ($fcv:expr, $value:expr, $ty:ty, $idx:expr) => {
@@ -356,14 +378,33 @@ macro_rules! add_named_constant {
         )
     };
 }
+
+impl Hash for ConstantValues {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        for (id, value) in &self.0 {
+            id.hash(state);
+            value.hash(state);
+        }
+    }
+}
 
 impl ConstantValues {
+    fn new() -> Self {
+        Self(BTreeMap::new())
+    }
+
+    fn set(mut self, id: impl Into<ConstantValueId>, value: impl Into<ConstantValueType>) -> Self {
+        self.0.insert(id.into(), value.into());
+        self
+    }
+
     fn create_function_constant_values(&self) -> FunctionConstantValues {
         use ConstantValueId::*;
-        use ConstantValue::*;
+        use ConstantValueType::*;
         let mut function_values = FunctionConstantValues::new();
         for (id, value) in &self.0 {
-            match (id, value) {
+            match (&id, &value) {
                 (Index(index), Float(value)) => {
                     add_indexed_constant!(function_values, value, f32, *index);
                 }
@@ -839,42 +880,227 @@ pub fn call_index_select(
     Ok(())
 }
 
+impl From<NSUInteger> for ConstantValueId {
+    fn from(idx: NSUInteger) -> Self {
+        Self::Index(idx)
+    }
+}
+
+impl From<usize> for ConstantValueId {
+    fn from(idx: usize) -> Self {
+        ConstantValueId::from(idx as NSUInteger)
+    }
+}
+
+impl From<i32> for ConstantValueId {
+    fn from(idx: i32) -> Self {
+        ConstantValueId::from(idx as NSUInteger)
+    }
+}
+
+impl From<&'static str> for ConstantValueId {
+    fn from(name: &'static str) -> Self {
+        Self::Name(name)
+    }
+}
+
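+// Conversions into `ConstantValueType` so that `ConstantValues::set` can take
+// plain Rust values; the optional `$via` type below is the representation that
+// is actually handed to Metal (e.g. `usize` is passed as a `u32` constant).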
to_constant_value { + ($ty:ty, $constant_value_type:ident) => { + to_constant_value!($ty, $ty, $constant_value_type); + }; + ($ty:ty, $via:ty, $constant_value_type:ident) => { + impl From<$ty> for ConstantValueType { + fn from(v: $ty) -> Self { + Self::$constant_value_type(v as $via) + } + } + }; +} +to_constant_value!(f32, Float); +to_constant_value!(u32, Uint); +to_constant_value!(usize, u32, Uint); +to_constant_value!(u16, UShort); +to_constant_value!(bool, Bool); + +struct MFAGemmConfig { + m: usize, + k: usize, + n: usize, + transpose_left: bool, + transpose_right: bool, + batched: bool, + m_simd: u16, + n_simd: u16, + k_simd: u16, + m_splits: u16, + n_splits: u16, + m_group: u16, + n_group: u16, +} + +impl From for ConstantValues { + fn from(conf: MFAGemmConfig) -> Self { + ConstantValues::new() + .set(0, conf.m) + .set(1, conf.k) + .set(2, conf.n) + .set(10, conf.transpose_left) + .set(11, conf.transpose_right) + .set(12, false) + .set(20, 1.0) + .set(21, 0.0) + .set(100, conf.batched) + .set(101, false) + .set(50001, false) + .set(200, conf.m_simd) + .set(201, conf.n_simd) + .set(202, conf.k_simd) + .set(210, conf.m_splits) + .set(211, conf.n_splits) + // garbage + .set(102, false) + .set(103, false) + .set(113, false) + .set(50000, false) + } +} + #[allow(clippy::too_many_arguments)] pub fn call_mfa_gemm( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, name: &'static str, - shape: &[usize], - input: &Buffer, - strides: &[usize], - offset: usize, + lhs: &Buffer, + lhs_dims: &[usize], + rhs: &Buffer, + rhs_dims: &[usize], output: &Buffer, - output_offset: usize, + (b, m, n, k): (usize, usize, usize, usize), + transpose_left: bool, + transpose_right: bool, ) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::MetalFlashAttention, name)?; + let batched = b > 1; + + let mut c_elements = m * n; + if batched { + c_elements *= 2; + } + + let is_half = name == "hgemm"; + let is_float = name == "sgemm"; + + let mut m_group = 32; + let mut n_group = 32; + let mut k_simd = 32; + if c_elements > 10 ^ 6 { + m_group = 48; + n_group = 48; + } + // If K_simd is perfectly equal to matrix K, the compiler can elide a large + // amount of logic in the kernel. 
+    if k >= 33 && k <= 40 {
+        k_simd = 40;
+    } else if is_half && k >= 73 && k <= 80 {
+        k_simd = 80;
+    } else if c_elements > 1_000_000 {
+        if k <= 16 {
+            k_simd = 16;
+        } else if k <= 24 {
+            k_simd = 24;
+        } else if k <= 32 {
+            k_simd = 32;
+        } else if k <= 48 {
+            k_simd = 24;
+        } else if k <= 64 {
+            k_simd = 32;
+        } else if is_float {
+            k_simd = 24;
+        }
+    }
+
+    let m_splits = 2;
+    let n_splits = 2;
+    let m_simd = m_group / m_splits;
+    let n_simd = n_group / n_splits;
+
+    let config = MFAGemmConfig {
+        m,
+        k,
+        n,
+        transpose_left,
+        transpose_right,
+        batched,
+        m_simd,
+        n_simd,
+        k_simd,
+        m_splits,
+        n_splits,
+        m_group,
+        n_group,
+    };
+
+    let pipeline = kernels.load_pipeline(
+        device,
+        Source::MetalFlashAttention,
+        KernelKey::new(name).with_constants(config.into()),
+    )?;
+    let block_type_size = if is_half { 2 } else { 4 };
+    let a_block_bytes = m_group * k_simd * block_type_size;
+    let b_block_bytes = k_simd * n_group * block_type_size;
+    let c_block_bytes = m_group * n_group * block_type_size;
+    let mut thread_group_memory_length = a_block_bytes + b_block_bytes;
+
+    if m % 8 > 0 && n % 8 > 0 {
+        thread_group_memory_length = core::cmp::max(thread_group_memory_length, c_block_bytes);
+    }
 
-    let num_dims: usize = shape.len();
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
+    encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
+    encoder.use_resources(&[&lhs, &rhs], MTLResourceUsage::Read);
+    encoder.use_resource(&output, MTLResourceUsage::Write);
+    encoder.set_buffers(0, &[Some(lhs), Some(rhs), Some(output)], &[0; 3]);
 
-    let length: usize = shape.iter().product();
-    set_params!(
-        encoder,
-        (
-            length,
-            num_dims,
-            shape,
-            strides,
-            (input, offset),
-            (output, output_offset)
-        )
+    let ceil_divide = |a, b| (a + b - 1) / b;
+
+    let mut grid_z = 1;
+
+    if batched {
+        grid_z = lhs_dims[..lhs_dims.len() - 2].iter().product();
+        let byte_stride = |shape: &[usize]| -> u64 {
+            let rank = shape.len();
+            let mut output = core::mem::size_of::<f32>() * shape[rank - 2] * shape[rank - 1];
+            if shape[..shape.len() - 2].iter().product::<usize>() == 1 {
+                output = 0;
+            }
+            output as u64
+        };
+        let byte_stride_a = byte_stride(lhs_dims);
+        let byte_stride_b = byte_stride(rhs_dims);
+        let byte_stride_c = byte_stride(&[m, n]);
+
+        type BatchConfig = (u64, u64, u64, u64);
+        let mut batching_conf: Vec<BatchConfig> = vec![];
+        for i in 0..grid_z {
+            batching_conf.push((
+                i as u64 * byte_stride_a,
+                i as u64 * byte_stride_b,
+                i as u64 * byte_stride_c,
+                0u64,
+            ));
+        }
+        set_param(encoder, 10, batching_conf.as_slice());
+    }
+
+    let grid_size = MTLSize::new(
+        ceil_divide(n as NSUInteger, n_group as NSUInteger),
+        ceil_divide(m as NSUInteger, m_group as NSUInteger),
+        grid_z as NSUInteger,
     );
 
-    let width: usize = shape.iter().product();
-    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
-
-    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    let group_size = MTLSize::new((32 * m_splits * n_splits) as NSUInteger, 1, 1);
+    encoder.dispatch_thread_groups(grid_size, group_size);
     encoder.end_encoding();
     Ok(())
 }
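
Usage sketch (not part of the patch): one way the new call_mfa_gemm entry point
could be driven from a small host program. It assumes the MetalFlashAttention
metallib is reachable through Kernels as in this branch; the new_buffer helper
below is hypothetical and only stages host data in shared-mode MTLBuffers.

    use candle_metal_kernels::{call_mfa_gemm, Kernels};
    use metal::{Device, MTLResourceOptions};

    // Hypothetical helper: copy a host slice into a shared-mode MTLBuffer.
    fn new_buffer(device: &Device, data: &[f32]) -> metal::Buffer {
        device.new_buffer_with_data(
            data.as_ptr() as *const core::ffi::c_void,
            (data.len() * core::mem::size_of::<f32>()) as u64,
            MTLResourceOptions::StorageModeShared,
        )
    }

    fn main() {
        let device = Device::system_default().expect("no Metal device");
        let kernels = Kernels::new();
        let queue = device.new_command_queue();
        let command_buffer = queue.new_command_buffer();

        // 2x2 times the 2x2 identity, single precision ("sgemm").
        let (b, m, n, k) = (1usize, 2usize, 2usize, 2usize);
        let lhs = new_buffer(&device, &[1.0, 2.0, 3.0, 4.0]);
        let rhs = new_buffer(&device, &[1.0, 0.0, 0.0, 1.0]);
        let out = new_buffer(&device, &[0.0; 4]);

        call_mfa_gemm(
            &device,
            command_buffer,
            &kernels,
            "sgemm",
            &lhs,
            &[m, k],
            &rhs,
            &[k, n],
            &out,
            (b, m, n, k),
            false,
            false,
        )
        .unwrap();

        command_buffer.commit();
        command_buffer.wait_until_completed();

        // Expect the output to equal lhs: [1.0, 2.0, 3.0, 4.0].
        let result =
            unsafe { core::slice::from_raw_parts(out.contents() as *const f32, m * n) };
        println!("{result:?}");
    }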