mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Working version for llama2-c.
This commit is contained in:
@ -8,7 +8,6 @@ use core::mem;
|
|||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use metal;
|
use metal;
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
use std::ops::Deref;
|
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
@ -191,8 +190,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
return Ok(Self {
|
return Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -260,8 +257,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
&mut buffer,
|
&mut buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
@ -318,8 +313,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -421,9 +414,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -508,9 +498,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -551,8 +538,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
&mut buffer,
|
&mut buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device,
|
device,
|
||||||
@ -663,8 +648,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
&mut buffer,
|
&mut buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -694,8 +677,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
// Create descriptors
|
// Create descriptors
|
||||||
use metal::mps::matrix::*;
|
use metal::mps::matrix::*;
|
||||||
|
|
||||||
assert_eq!(self.dtype, rhs.dtype);
|
|
||||||
|
|
||||||
let (type_id, size) = match self.dtype {
|
let (type_id, size) = match self.dtype {
|
||||||
DType::F32 => (
|
DType::F32 => (
|
||||||
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
||||||
@ -739,6 +720,26 @@ impl BackendStorage for MetalStorage {
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?
|
})?
|
||||||
};
|
};
|
||||||
|
let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
|
||||||
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => m * k,
|
||||||
|
_ => Err(MetalError::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?,
|
||||||
|
} as u64;
|
||||||
|
let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
|
||||||
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||||
|
[stride] => stride,
|
||||||
|
[] => n * k,
|
||||||
|
_ => Err(MetalError::MatMulNonContiguous {
|
||||||
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
|
rhs_stride: rhs_stride.to_vec(),
|
||||||
|
mnk: (m, n, k),
|
||||||
|
})?,
|
||||||
|
} as u64;
|
||||||
|
|
||||||
let b = b as NSUInteger;
|
let b = b as NSUInteger;
|
||||||
let m = m as NSUInteger;
|
let m = m as NSUInteger;
|
||||||
@ -758,12 +759,13 @@ impl BackendStorage for MetalStorage {
|
|||||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||||
|
|
||||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||||
|
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
for bi in 0..b {
|
for bi in 0..b {
|
||||||
// Create matrix objects
|
// Create matrix objects
|
||||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
bi * m * k * size,
|
(bi * stride_left + lhs_l.start_offset() as u64) * size,
|
||||||
&left_descriptor,
|
&left_descriptor,
|
||||||
)
|
)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
@ -771,7 +773,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
})?;
|
})?;
|
||||||
let right_matrix = Matrix::init_with_buffer_descriptor(
|
let right_matrix = Matrix::init_with_buffer_descriptor(
|
||||||
&rhs.buffer,
|
&rhs.buffer,
|
||||||
bi * n * k * size,
|
(bi * stride_right + rhs_l.start_offset() as u64) * size,
|
||||||
&right_descriptor,
|
&right_descriptor,
|
||||||
)
|
)
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
@ -806,14 +808,12 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
// Encode kernel to command buffer
|
// Encode kernel to command buffer
|
||||||
matrix_multiplication.encode_to_command_buffer(
|
matrix_multiplication.encode_to_command_buffer(
|
||||||
command_buffer.deref(),
|
&command_buffer,
|
||||||
&left_matrix,
|
&left_matrix,
|
||||||
&right_matrix,
|
&right_matrix,
|
||||||
&result_matrix,
|
&result_matrix,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: out_buffer,
|
buffer: out_buffer,
|
||||||
@ -849,8 +849,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
dst_offset * dst.dtype.size_in_bytes(),
|
dst_offset * dst.dtype.size_in_bytes(),
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
// command_buffer.commit();
|
|
||||||
// command_buffer.wait_until_scheduled();
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user