Mirror of https://github.com/huggingface/candle.git (synced 2025-06-21 04:10:46 +00:00)
Stash for debugging
This commit is contained in:
@ -795,14 +795,16 @@ impl BackendStorage for MetalStorage {
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
// Create descriptors
|
||||
let (type_id, size) = match self.dtype {
|
||||
let (type_id, size, name) = match self.dtype {
|
||||
DType::F32 => (
|
||||
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
||||
core::mem::size_of::<f32>() as NSUInteger,
|
||||
"sgemm",
|
||||
),
|
||||
DType::F16 => (
|
||||
metal::mps::MPS_FLOATBIT_ENCODING | 16,
|
||||
core::mem::size_of::<f16>() as NSUInteger,
|
||||
"hgemm",
|
||||
),
|
||||
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
|
||||
};
|
||||
@ -836,60 +838,37 @@ impl BackendStorage for MetalStorage {
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
let b = b as NSUInteger;
|
||||
let m = m as NSUInteger;
|
||||
let n = n as NSUInteger;
|
||||
let k = k as NSUInteger;
|
||||
|
||||
let left_matrix = self.matrix(
|
||||
(b, m, k),
|
||||
transpose_left,
|
||||
size,
|
||||
lhs_l.start_offset() as NSUInteger * size,
|
||||
type_id,
|
||||
)?;
|
||||
let right_matrix = rhs.matrix(
|
||||
(b, k, n),
|
||||
transpose_right,
|
||||
size,
|
||||
rhs_l.start_offset() as NSUInteger * size,
|
||||
type_id,
|
||||
)?;
|
||||
let (result_matrix, out_buffer) =
|
||||
self.device
|
||||
.new_matrix((b, m, n), size, type_id, self.dtype)?;
|
||||
let result_buffer = self.device.new_buffer(b * m * n, self.dtype);
|
||||
|
||||
let command_buffer = self.device.command_buffer();
|
||||
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
command_buffer.set_label("mfa gemm");
|
||||
|
||||
candle_metal_kernels::call_mfa_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
&self.buffer,
|
||||
lhs_l.shape().dims(),
|
||||
&rhs.buffer,
|
||||
rhs_l.shape().dims(),
|
||||
&result_buffer,
|
||||
(b, m, n, k),
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
&command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
command_buffer.set_label("matmul");
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
|
||||
Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
|
||||
Ok(Self::new(
|
||||
self.buffer.clone(),
|
||||
self.device.clone(),
|
||||
self.dtype(),
|
||||
))
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
|
Reference in New Issue
Block a user