From 79845bd93b0a4743ff39b44934ead166fa245f1e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 13 Nov 2023 12:36:27 +0100 Subject: [PATCH] Working version for llama2-c. --- candle-core/src/metal_backend.rs | 50 +++++++++++++++----------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index eadbf1f1..eca8e1fe 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -8,7 +8,6 @@ use core::mem; use half::{bf16, f16}; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; -use std::ops::Deref; use std::sync::{Arc, RwLock}; /// Metal related errors @@ -191,8 +190,6 @@ impl BackendStorage for MetalStorage { ) .unwrap(); } - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); return Ok(Self { buffer, device: device.clone(), @@ -260,8 +257,6 @@ impl BackendStorage for MetalStorage { &mut buffer, ) .map_err(MetalError::from)?; - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); Ok(Self { buffer, @@ -318,8 +313,6 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); Ok(Self { buffer, device: device.clone(), @@ -421,9 +414,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); - Ok(Self { buffer, device: device.clone(), @@ -508,9 +498,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); - Ok(Self { buffer, device: device.clone(), @@ -551,8 +538,6 @@ impl BackendStorage for MetalStorage { &mut buffer, ) .map_err(MetalError::from)?; - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); Ok(Self { buffer, device, @@ -663,8 +648,6 @@ impl BackendStorage for MetalStorage { &mut buffer, ) .map_err(MetalError::from)?; - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); Ok(Self { buffer, device: device.clone(), @@ -694,8 +677,6 @@ impl BackendStorage for MetalStorage { // Create descriptors use metal::mps::matrix::*; - assert_eq!(self.dtype, rhs.dtype); - let (type_id, size) = match self.dtype { DType::F32 => ( metal::mps::MPS_FLOATBIT_ENCODING | 32, @@ -739,6 +720,26 @@ impl BackendStorage for MetalStorage { mnk: (m, n, k), })? }; + let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] { + [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, + [stride] => stride, + [] => m * k, + _ => Err(MetalError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?, + } as u64; + let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] { + [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, + [stride] => stride, + [] => n * k, + _ => Err(MetalError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?, + } as u64; let b = b as NSUInteger; let m = m as NSUInteger; @@ -758,12 +759,13 @@ impl BackendStorage for MetalStorage { let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); let out_buffer = self.device.new_buffer(elem_count, self.dtype); + let command_buffer = self.device.command_buffer(); for bi in 0..b { // Create matrix objects let left_matrix = Matrix::init_with_buffer_descriptor( &self.buffer, - bi * m * k * size, + (bi * stride_left + lhs_l.start_offset() as u64) * size, &left_descriptor, ) .ok_or_else(|| { @@ -771,7 +773,7 @@ impl BackendStorage for MetalStorage { })?; let right_matrix = Matrix::init_with_buffer_descriptor( &rhs.buffer, - bi * n * k * size, + (bi * stride_right + rhs_l.start_offset() as u64) * size, &right_descriptor, ) .ok_or_else(|| { @@ -806,14 +808,12 @@ impl BackendStorage for MetalStorage { // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( - command_buffer.deref(), + &command_buffer, &left_matrix, &right_matrix, &result_matrix, ); } - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); Ok(Self { buffer: out_buffer, @@ -849,8 +849,6 @@ impl BackendStorage for MetalStorage { dst_offset * dst.dtype.size_in_bytes(), ) .map_err(MetalError::from)?; - // command_buffer.commit(); - // command_buffer.wait_until_scheduled(); Ok(()) } }