mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Working version for llama2-c.
This commit is contained in:
@ -8,7 +8,6 @@ use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::ops::Deref;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Metal related errors
|
||||
@ -191,8 +190,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
return Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -260,8 +257,6 @@ impl BackendStorage for MetalStorage {
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
@ -318,8 +313,6 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -421,9 +414,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -508,9 +498,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -551,8 +538,6 @@ impl BackendStorage for MetalStorage {
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
@ -663,8 +648,6 @@ impl BackendStorage for MetalStorage {
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -694,8 +677,6 @@ impl BackendStorage for MetalStorage {
|
||||
// Create descriptors
|
||||
use metal::mps::matrix::*;
|
||||
|
||||
assert_eq!(self.dtype, rhs.dtype);
|
||||
|
||||
let (type_id, size) = match self.dtype {
|
||||
DType::F32 => (
|
||||
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
||||
@ -739,6 +720,26 @@ impl BackendStorage for MetalStorage {
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
} as u64;
|
||||
let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
} as u64;
|
||||
|
||||
let b = b as NSUInteger;
|
||||
let m = m as NSUInteger;
|
||||
@ -758,12 +759,13 @@ impl BackendStorage for MetalStorage {
|
||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
|
||||
let command_buffer = self.device.command_buffer();
|
||||
for bi in 0..b {
|
||||
// Create matrix objects
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&self.buffer,
|
||||
bi * m * k * size,
|
||||
(bi * stride_left + lhs_l.start_offset() as u64) * size,
|
||||
&left_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
@ -771,7 +773,7 @@ impl BackendStorage for MetalStorage {
|
||||
})?;
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&rhs.buffer,
|
||||
bi * n * k * size,
|
||||
(bi * stride_right + rhs_l.start_offset() as u64) * size,
|
||||
&right_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
@ -806,14 +808,12 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
command_buffer.deref(),
|
||||
&command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
}
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
@ -849,8 +849,6 @@ impl BackendStorage for MetalStorage {
|
||||
dst_offset * dst.dtype.size_in_bytes(),
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// command_buffer.commit();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user