diff --git a/Cargo.toml b/Cargo.toml
index 7c2e3a7d..1d91be87 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"], package = "candle-metal" }
 
 [profile.release-with-debug]
 inherits = "release"
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 012695dd..5a764fff 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -10,7 +10,7 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"], package = "candle-metal" }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 60f9b8a6..4c640140 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -5,6 +5,7 @@ use metal::{
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::sync::RwLock;
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
 
 const AFFINE: &str = include_str!("affine.metal");
 const INDEXING: &str = include_str!("indexing.metal");
@@ -1052,123 +1053,93 @@ pub fn call_gemm(
             mnk: (m, n, k),
         })?;
     };
-    let d_trans = false;
-    let alpha = 1.0f32;
-    let beta = 0.0f32;
-    let batched = b > 1;
-    let fused_activation = false;
-    let fused_bias = false;
-    let m_simd = 16;
-    let n_simd = 16;
-    let k_simd = 16;
-    let m_splits = 2;
-    let n_splits = 2;
-    let constants = Some(ConstantValues::new(vec![
-        (0, Value::USize(m)),
-        (1, Value::USize(n)),
-        (2, Value::USize(k)),
-        (10, Value::Bool(a_trans)),
-        (11, Value::Bool(b_trans)),
-        (13, Value::Bool(d_trans)),
-        (20, Value::F32(alpha)),
-        (21, Value::F32(beta)),
-        (100, Value::Bool(batched)),
-        (101, Value::Bool(fused_activation)),
-        // Garbage
-        (102, Value::Bool(false)),
-        (103, Value::Bool(false)),
-        (113, Value::Bool(false)),
-        (50_000, Value::Bool(false)),
-        // End garbage
-        (200, Value::U16(m_simd)),
-        (201, Value::U16(n_simd)),
-        (202, Value::U16(k_simd)),
-        (210, Value::U16(m_splits)),
-        (211, Value::U16(n_splits)),
-        (50_001, Value::Bool(fused_bias)),
-    ]));
-    let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
-    let m_group = m_simd * m_splits;
-    let n_group = n_simd * n_splits;
-
-    let a_block_length = m_group * k_simd;
-    let b_block_length = k_simd * n_group;
-
-    let mut block_elements = a_block_length + b_block_length;
-    if (m % 8 != 0) && (n % 8 != 0) {
-        let c_block_length = m_group * n_group;
-        block_elements = std::cmp::max(c_block_length, block_elements)
-    }
-    if fused_bias {
-        if d_trans {
-            block_elements = std::cmp::max(block_elements, m_group);
-        } else {
-            block_elements = std::cmp::max(block_elements, n_group);
-        }
-    }
-    let bytes = match name {
-        "sgemm" => 4,
-        "hgemm" => 2,
-        other => {
-            return Err(MetalKernelError::LoadLibraryError(format!(
-                "{other} is not a valid kernel for gemm"
-            )));
-        }
-    };
-    let block_bytes = block_elements * bytes;
-
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
-    encoder.set_compute_pipeline_state(&pipeline);
-    encoder.set_threadgroup_memory_length(0, block_bytes.into());
-    encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
-    encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
-    encoder.set_buffer(2, Some(output), 0);
-    // TODO Tensor D
-
-    let grid_z = b;
-    if batched {
-        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
-        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
-        let byte_stride_c = m * n * bytes as usize;
-        // TODO byte_stride_d
-        let byte_stride_d = 0;
-
-        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
-        for i in 0..b {
-            buffer.push((i * byte_stride_a) as u64);
-            buffer.push((i * byte_stride_b) as u64);
-            buffer.push((i * byte_stride_c) as u64);
-            buffer.push((i * byte_stride_d) as u64);
-        }
-        encoder.set_bytes(
-            10,
-            (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
-            buffer.as_ptr() as *const NSUInteger as *const c_void,
-        );
-    }
-
-    let grid_size = MTLSize {
-        width: divide(n, n_group.into()),
-        height: divide(m, m_group.into()),
-        depth: grid_z as NSUInteger,
-    };
-    let group_size = MTLSize {
-        width: 32 * (m_splits as u64) * (n_splits as u64),
-        height: 1,
-        depth: 1,
-    };
-    // println!("grid size {grid_size:?} group size {group_size:?}");
-    encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
-    encoder.dispatch_thread_groups(grid_size, group_size);
-    encoder.update_fence(&kernels.fence);
-    encoder.end_encoding();
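+    // The hand-written MFA GEMM above is replaced with MPS: the lhs/rhs/output
+    // buffers are wrapped in MPSMatrix objects and MPSMatrixMultiplication
+    // encodes the multiply onto the same command buffer.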
+    let (b, m, n, k) = (
+        b as NSUInteger,
+        m as NSUInteger,
+        n as NSUInteger,
+        k as NSUInteger,
+    );
+
+    // MPSDataTypeFloat32 = MPSDataTypeFloatBit (0x10000000) | 32,
+    // MPSDataTypeFloat16 = MPSDataTypeFloatBit (0x10000000) | 16.
+    let (size, data_type) = match name {
+        "sgemm" => (4, 0x10000000 | 32),
+        "hgemm" => (2, 0x10000000 | 16),
+        other => {
+            return Err(MetalKernelError::LoadLibraryError(format!(
+                "{other} is not a valid kernel for gemm"
+            )));
+        }
+    };
+
+    let left_matrix = create_matrix(
+        lhs_buffer,
+        (b, m, k),
+        a_trans,
+        size,
+        lhs_offset as NSUInteger,
+        data_type,
+    )
+    .unwrap();
+    let right_matrix = create_matrix(
+        rhs_buffer,
+        (b, k, n),
+        b_trans,
+        size,
+        rhs_offset as NSUInteger,
+        data_type,
+    )
+    .unwrap();
+    let result_matrix = create_matrix(output, (b, m, n), false, size, 0, data_type).unwrap();
+
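+    // MPSMatrixMultiplication computes C = alpha * op(A) * op(B) + beta * C;
+    // alpha = 1.0 and beta = 0.0 reproduce the plain GEMM of the old kernel.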
+    let matrix_multiplication =
+        MatrixMultiplication::init(device, a_trans, b_trans, m, n, k, 1.0, 0.0).unwrap();
+    matrix_multiplication.encode_to_command_buffer(
+        command_buffer,
+        &left_matrix,
+        &right_matrix,
+        &result_matrix,
+    );
 
     Ok(())
 }
 
+fn create_matrix(
+    buffer: &Buffer,
+    (b, rows, columns): (NSUInteger, NSUInteger, NSUInteger),
+    transpose: bool,
+    size: NSUInteger,
+    offset: NSUInteger,
+    data_type: u32,
+) -> Option<Matrix> {
+    // A transposed operand swaps rows and columns in its descriptor.
+    let (rows, columns) = if transpose {
+        (columns, rows)
+    } else {
+        (rows, columns)
+    };
+    let descriptor = if b == 1 {
+        MatrixDescriptor::init_single(rows, columns, columns * size, data_type)
+    } else {
+        MatrixDescriptor::init_multiple(
+            rows,
+            columns,
+            b,
+            columns * size,
+            rows * columns * size,
+            data_type,
+        )
+    };
+    // `offset` is already a byte offset (the old kernel passed it straight to
+    // set_buffer), so it is not scaled by the element size here.
+    Matrix::init_with_buffer_descriptor(buffer, offset, &descriptor)
+}
+
 fn divide(m: usize, b: usize) -> NSUInteger {
     ((m + b - 1) / b) as NSUInteger
 }
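
Note: MPSMatrix storage is row-major, so the descriptor math in create_matrix
is rowBytes = columns * size and, in the batched case, matrixBytes =
rows * columns * size. A minimal standalone sketch of that arithmetic (plain
Rust, no Metal device required; the concrete numbers are illustrative only):

    // rowBytes / matrixBytes arithmetic as used by create_matrix,
    // shown for a batch of two 4x3 f32 matrices.
    fn main() {
        let (b, rows, columns, size): (u64, u64, u64, u64) = (2, 4, 3, 4);
        let row_bytes = columns * size; // 12 bytes per row
        let matrix_bytes = rows * columns * size; // 48 bytes per matrix in the batch
        assert_eq!(row_bytes, 12);
        assert_eq!(matrix_bytes, 48);
        println!("batch={b}, row_bytes={row_bytes}, matrix_bytes={matrix_bytes}");
    }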