From d9c1f7e2012d1dc29f613e60d19b4e0b49bf01bd Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 10 Nov 2023 20:09:25 +0100
Subject: [PATCH] Fixed matmul (display still broken without casting back to CPU first?)

---
 Cargo.toml                                 |   3 +-
 candle-core/src/metal_backend.rs           | 231 +++++++++++----------
 candle-metal-kernels/Cargo.toml            |   3 +-
 candle-transformers/src/models/llama2_c.rs |   3 +-
 4 files changed, 129 insertions(+), 111 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 2a2bd9cb..b6517856 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -60,7 +60,8 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../metal-rs", features = ["mps"] }
 
 [profile.release-with-debug]
 inherits = "release"
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 6687534d..fefafb2f 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -19,6 +19,13 @@ pub enum MetalError {
     Message(String),
     #[error(transparent)]
     KernelError(#[from] candle_metal_kernels::MetalKernelError),
+
+    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
+    MatMulNonContiguous {
+        lhs_stride: Vec<usize>,
+        rhs_stride: Vec<usize>,
+        mnk: (usize, usize, usize),
+    },
 }
 
 impl From<String> for MetalError {
@@ -53,7 +60,7 @@ impl MetalDevice {
     //     self.device.as_ref()
     // }
 
-    pub fn id(&self) -> u64 {
+    pub fn id(&self) -> NSUInteger {
         self.registry_id()
     }
 
@@ -70,7 +77,7 @@ impl MetalDevice {
     }
 
     pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
-        let size = (element_count * dtype.size_in_bytes()) as u64;
+        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
         // debug!("Allocate 1 - buffer size {size}");
         self.device
             .new_buffer(size, MTLResourceOptions::StorageModeManaged)
@@ -520,6 +527,7 @@ impl BackendStorage for MetalStorage {
             (left, right) => todo!("index select metal {left:?} {right:?}"),
         };
         let command_buffer = self.device.command_queue.new_command_buffer();
+        // println!("INDEX SELECT");
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -561,20 +569,117 @@ impl BackendStorage for MetalStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        let transpose_left = false;
-        let transpose_right = !rhs_l.is_contiguous();
-        let alpha = 1.0;
-        let beta = 0.0;
-        self.matmul_generic(
-            rhs,
-            (b, m, n, k),
-            lhs_l,
-            rhs_l,
+        // Create descriptors
+        use metal::mps::matrix::*;
+        let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
+        let size = core::mem::size_of::<f32>() as NSUInteger;
+
+        let elem_count = b * m * n;
+
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+        // The a tensor has dims batching, k, n (rhs)
+        let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
+            false
+        } else if lhs_m1 == m && lhs_m2 == 1 {
+            true
+        } else {
+            Err(MetalError::MatMulNonContiguous {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+                mnk: (m, n, k),
+            })?
+        };
+        let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
+            false
+        } else if rhs_m1 == k && rhs_m2 == 1 {
+            true
+        } else {
+            Err(MetalError::MatMulNonContiguous {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+                mnk: (m, n, k),
+            })?
+        };
+        // println!("{transpose_left} {transpose_right}");
+
+        let b = b as NSUInteger;
+        let m = m as NSUInteger;
+        let n = n as NSUInteger;
+        let k = k as NSUInteger;
+
+        let left_descriptor = if transpose_left {
+            MatrixDescriptor::init_single(k, m, m * size, type_id)
+        } else {
+            MatrixDescriptor::init_single(m, k, k * size, type_id)
+        };
+        let right_descriptor = if transpose_right {
+            MatrixDescriptor::init_single(n, k, k * size, type_id)
+        } else {
+            MatrixDescriptor::init_single(k, n, n * size, type_id)
+        };
+        let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
+
+        // Create matrix objects
+        let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+        let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+
+        let out_buffer = self.device.new_buffer(elem_count, self.dtype);
+        let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+
+        let alpha = 1.0f64;
+        let beta = 0.0f64;
+        // Create kernel
+        let matrix_multiplication = MatrixMultiplication::init(
+            &self.device,
             transpose_left,
             transpose_right,
+            m,
+            n,
+            k,
             alpha,
             beta,
         )
+        .ok_or_else(|| {
+            MetalError::from("Failed to create matrix multiplication kernel".to_string())
+        })?;
+
+        matrix_multiplication.set_batch_size(b);
+
+        // Encode kernel to command buffer
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        matrix_multiplication.encode_to_command_buffer(
+            command_buffer,
+            &left_matrix,
+            &right_matrix,
+            &result_matrix,
+        );
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        // let left = self.buffer.read_to_vec::<f32>(10);
+        // let right = rhs.buffer.read_to_vec::<f32>(10);
+        // let out = out_buffer.read_to_vec::<f32>(40);
+        // todo!("Out {left:?} {right:?} {out:?}");
+
+        Ok(Self {
+            buffer: out_buffer,
+            device: self.device.clone(),
+            dtype: self.dtype(),
+        })
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
@@ -583,18 +688,6 @@ impl BackendStorage for MetalStorage {
         if el_count == 0 {
             return Ok(());
         }
-        // todo!("Copy strided {:?}", src_l.is_contiguous());
-        // if src_l.is_contiguous() {
-        //     let command_buffer = self.device.command_queue.new_command_buffer();
-        //     let blip = command_buffer.new_blit_command_encoder();
-        //     blip.copy_from_buffer(
-        //         &self.buffer,
-        //         src_l.start_offset() as u64,
-        //         &dst.buffer,
-        //         dst_offset as u64,
-        //         self.buffer.length(),
-        //     );
-        // } else {
         let command_buffer = self.device.command_queue.new_command_buffer();
         let kernel_name = match self.dtype {
             DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
@@ -631,84 +724,6 @@ impl MetalStorage {
             dtype,
         }
     }
-    pub(crate) fn matmul_generic(
-        &self,
-        rhs: &Self,
-        (b, m, n, k): (usize, usize, usize, usize),
-        lhs_l: &Layout,
-        rhs_l: &Layout,
-        transpose_left: bool,
-        transpose_right: bool,
-        alpha: f64,
-        beta: f64,
-    ) -> Result<Self> {
-        let elem_count = b * m * n;
-        match (self.dtype, rhs.dtype) {
-            (DType::F32, DType::F32) => {
-                let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
-                // if b != 1 {
-                //     // debug!("TODO implement batched matmul for B={b}");
-                //     crate::bail!("Didn't implemented strided matmul yet");
-                //     return Ok(Self {
-                //         buffer: out_buffer,
-                //         device: self.device.clone(),
-                //         dtype: self.dtype(),
-                //     });
-                //}
-                // if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
-                //     // debug!(
-                //     //     "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
-                //     //     lhs_l.is_contiguous(),
-                //     //     rhs_l.is_contiguous(),
-                //     //     rhs_l
-                //     // );
-                //     crate::bail!("No not contiguous matmul");
-                //     return Ok(Self {
-                //         buffer: out_buffer,
-                //         device: self.device.clone(),
-                //         dtype: self.dtype(),
-                //     });
-                // }
-
-                // debug!("TODO GEMM");
-                let command_buffer = self.device.command_queue.new_command_buffer();
-                encode_gemm::(
-                    &self.device,
-                    &command_buffer,
-                    transpose_left,
-                    transpose_right,
-                    &self.buffer,
-                    &rhs.buffer,
-                    &mut out_buffer,
-                    m as NSUInteger,
-                    n as NSUInteger,
-                    k as NSUInteger,
-                    alpha as f32,
-                    beta as f32,
-                    Some(b as NSUInteger),
-                )
-                .map_err(MetalError::from)?;
-
-                command_buffer.commit();
-                command_buffer.wait_until_completed();
-                // command_buffer.wait_until_scheduled();
-                //
-                let left = self.buffer.read_to_vec::<f32>(10);
-                let right = rhs.buffer.read_to_vec::<f32>(10);
-                let out = out_buffer.read_to_vec::<f32>(10);
-
-                println!("{b} {m} {n} {k} ");
-                println!("{left:?} {right:?} {out:?}");
-
-                Ok(Self {
-                    buffer: out_buffer,
-                    device: self.device.clone(),
-                    dtype: self.dtype(),
-                })
-            }
-            _ => todo!("Unimplemented matmul for this pair"),
-        }
-    }
 
     pub fn buffer(&self) -> &Buffer {
         &self.buffer
@@ -774,37 +789,37 @@ impl BackendDevice for MetalDevice {
         let buffer = match storage {
             CpuStorage::U8(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u8>()) as u64,
+                (storage.len() * mem::size_of::<u8>()) as NSUInteger,
                 option,
             ),
             CpuStorage::U32(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u32>()) as u64,
+                (storage.len() * mem::size_of::<u32>()) as NSUInteger,
                 option,
             ),
             CpuStorage::I64(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<i64>()) as u64,
+                (storage.len() * mem::size_of::<i64>()) as NSUInteger,
                 option,
             ),
             CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<bf16>()) as u64,
+                (storage.len() * mem::size_of::<bf16>()) as NSUInteger,
                 option,
             ),
             CpuStorage::F16(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f16>()) as u64,
+                (storage.len() * mem::size_of::<f16>()) as NSUInteger,
                 option,
            ),
             CpuStorage::F32(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f32>()) as u64,
+                (storage.len() * mem::size_of::<f32>()) as NSUInteger,
                 option,
             ),
             CpuStorage::F64(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f64>()) as u64,
+                (storage.len() * mem::size_of::<f64>()) as NSUInteger,
                 option,
             ),
         };
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 2585ca62..2d2742ab 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -10,7 +10,8 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../../metal-rs", features = ["mps"] }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs
index 24182b72..aba9a547 100644
--- a/candle-transformers/src/models/llama2_c.rs
+++ b/candle-transformers/src/models/llama2_c.rs
@@ -156,6 +156,7 @@ impl CausalSelfAttention {
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
         let x0 = x.narrow(D::Minus1, 0, 1)?;
         let x1 = x.narrow(D::Minus1, 1, 1)?;
+        todo!("X {x1}");
         let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
         let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
         let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@@ -165,7 +166,6 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
-        todo!("X {q}");
         let k = self.k_proj.forward(x)?;
         let v = self.v_proj.forward(x)?;
 
@@ -174,6 +174,7 @@ impl CausalSelfAttention {
         let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
+        todo!("X {q}");
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
 
         if self.cache.use_kv_cache {