mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Use metal encode_gemm
This commit is contained in:
@ -1,15 +1,15 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::bail;
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
|
||||
use crate::error::Error;
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||
use metal::mps::{Float32, MPSDataType};
|
||||
use metal::{Buffer, MTLResourceOptions};
|
||||
use metal::mps::matrix::encode_gemm;
|
||||
use metal::mps::Float32;
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Metal related errors
|
||||
@ -288,7 +288,7 @@ impl MetalStorage {
|
||||
let elem_count = b * m * n;
|
||||
match (self.dtype, rhs.dtype) {
|
||||
(DType::F32, DType::F32) => {
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
if b != 1 {
|
||||
println!("TODO implement batched matmul for B={b}");
|
||||
// bail!("Didn't implemented strided matmul yet");
|
||||
@ -310,54 +310,28 @@ impl MetalStorage {
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
}
|
||||
let m: u64 = m.try_into().expect("usize should fit u64");
|
||||
let n: u64 = n.try_into().expect("usize should fit u64");
|
||||
let k: u64 = k.try_into().expect("usize should fit u64");
|
||||
// Create descriptors
|
||||
let left_descriptor =
|
||||
MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
|
||||
let right_descriptor =
|
||||
MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
|
||||
let result_descriptor =
|
||||
MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
|
||||
|
||||
encode_gemm::<Float32, Float32, Float32>(
|
||||
&self.device,
|
||||
&self.device.command_buffer,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&mut out_buffer,
|
||||
m as NSUInteger,
|
||||
n as NSUInteger,
|
||||
k as NSUInteger,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.map_err(|e| Error::Metal(e))?;
|
||||
|
||||
println!("lhs {:?} {m} {k}", self.buffer.length());
|
||||
println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
||||
println!("out {:?} {m} {n}", out_buffer.length());
|
||||
// Create matrix objects
|
||||
let left_matrix =
|
||||
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
|
||||
.expect("Failed to create left matrix");
|
||||
let right_matrix =
|
||||
Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
|
||||
.expect("Failed to create left matrix");
|
||||
|
||||
let result_matrix =
|
||||
Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
|
||||
.expect("Failed to create left matrix");
|
||||
|
||||
println!("lhs {:?}", lhs_l.shape());
|
||||
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.expect("Failed to create matrix multiplication kernel");
|
||||
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
&self.device.command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
Ok(Self {
|
||||
buffer: Arc::new(out_buffer),
|
||||
device: self.device.clone(),
|
||||
|
Reference in New Issue
Block a user