Working version for llama2-c.

This commit is contained in:
Nicolas Patry
2023-11-13 12:36:27 +01:00
parent 6071797450
commit 79845bd93b

View File

@ -8,7 +8,6 @@ use core::mem;
use half::{bf16, f16};
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::ops::Deref;
use std::sync::{Arc, RwLock};
/// Metal related errors
@ -191,8 +190,6 @@ impl BackendStorage for MetalStorage {
)
.unwrap();
}
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
return Ok(Self {
buffer,
device: device.clone(),
@ -260,8 +257,6 @@ impl BackendStorage for MetalStorage {
&mut buffer,
)
.map_err(MetalError::from)?;
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
@ -318,8 +313,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
}
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
device: device.clone(),
@ -421,9 +414,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
device: device.clone(),
@ -508,9 +498,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
device: device.clone(),
@ -551,8 +538,6 @@ impl BackendStorage for MetalStorage {
&mut buffer,
)
.map_err(MetalError::from)?;
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
device,
@ -663,8 +648,6 @@ impl BackendStorage for MetalStorage {
&mut buffer,
)
.map_err(MetalError::from)?;
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer,
device: device.clone(),
@ -694,8 +677,6 @@ impl BackendStorage for MetalStorage {
// Create descriptors
use metal::mps::matrix::*;
assert_eq!(self.dtype, rhs.dtype);
let (type_id, size) = match self.dtype {
DType::F32 => (
metal::mps::MPS_FLOATBIT_ENCODING | 32,
@ -739,6 +720,26 @@ impl BackendStorage for MetalStorage {
mnk: (m, n, k),
})?
};
let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
} as u64;
let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
} as u64;
let b = b as NSUInteger;
let m = m as NSUInteger;
@ -758,12 +759,13 @@ impl BackendStorage for MetalStorage {
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
bi * m * k * size,
(bi * stride_left + lhs_l.start_offset() as u64) * size,
&left_descriptor,
)
.ok_or_else(|| {
@ -771,7 +773,7 @@ impl BackendStorage for MetalStorage {
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
bi * n * k * size,
(bi * stride_right + rhs_l.start_offset() as u64) * size,
&right_descriptor,
)
.ok_or_else(|| {
@ -806,14 +808,12 @@ impl BackendStorage for MetalStorage {
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
command_buffer.deref(),
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
}
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer: out_buffer,
@ -849,8 +849,6 @@ impl BackendStorage for MetalStorage {
dst_offset * dst.dtype.size_in_bytes(),
)
.map_err(MetalError::from)?;
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(())
}
}