mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Use metal encode_gemm
This commit is contained in:
@ -1,15 +1,15 @@
|
|||||||
use crate::backend::{BackendDevice, BackendStorage};
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
use crate::bail;
|
|
||||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
|
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
|
||||||
|
use crate::error::Error;
|
||||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
use candle_metal_kernels;
|
use candle_metal_kernels;
|
||||||
use core::mem;
|
use core::mem;
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use metal;
|
use metal;
|
||||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
use metal::mps::matrix::encode_gemm;
|
||||||
use metal::mps::{Float32, MPSDataType};
|
use metal::mps::Float32;
|
||||||
use metal::{Buffer, MTLResourceOptions};
|
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
@ -288,7 +288,7 @@ impl MetalStorage {
|
|||||||
let elem_count = b * m * n;
|
let elem_count = b * m * n;
|
||||||
match (self.dtype, rhs.dtype) {
|
match (self.dtype, rhs.dtype) {
|
||||||
(DType::F32, DType::F32) => {
|
(DType::F32, DType::F32) => {
|
||||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||||
if b != 1 {
|
if b != 1 {
|
||||||
println!("TODO implement batched matmul for B={b}");
|
println!("TODO implement batched matmul for B={b}");
|
||||||
// bail!("Didn't implemented strided matmul yet");
|
// bail!("Didn't implemented strided matmul yet");
|
||||||
@ -310,54 +310,28 @@ impl MetalStorage {
|
|||||||
dtype: self.dtype(),
|
dtype: self.dtype(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
let m: u64 = m.try_into().expect("usize should fit u64");
|
|
||||||
let n: u64 = n.try_into().expect("usize should fit u64");
|
encode_gemm::<Float32, Float32, Float32>(
|
||||||
let k: u64 = k.try_into().expect("usize should fit u64");
|
&self.device,
|
||||||
// Create descriptors
|
&self.device.command_buffer,
|
||||||
let left_descriptor =
|
transpose_left,
|
||||||
MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
|
transpose_right,
|
||||||
let right_descriptor =
|
&self.buffer,
|
||||||
MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
|
&rhs.buffer,
|
||||||
let result_descriptor =
|
&mut out_buffer,
|
||||||
MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
|
m as NSUInteger,
|
||||||
|
n as NSUInteger,
|
||||||
|
k as NSUInteger,
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
)
|
||||||
|
.map_err(|e| Error::Metal(e))?;
|
||||||
|
|
||||||
println!("lhs {:?} {m} {k}", self.buffer.length());
|
println!("lhs {:?} {m} {k}", self.buffer.length());
|
||||||
println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
||||||
println!("out {:?} {m} {n}", out_buffer.length());
|
println!("out {:?} {m} {n}", out_buffer.length());
|
||||||
// Create matrix objects
|
|
||||||
let left_matrix =
|
|
||||||
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
|
|
||||||
.expect("Failed to create left matrix");
|
|
||||||
let right_matrix =
|
|
||||||
Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
|
|
||||||
.expect("Failed to create left matrix");
|
|
||||||
|
|
||||||
let result_matrix =
|
|
||||||
Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
|
|
||||||
.expect("Failed to create left matrix");
|
|
||||||
|
|
||||||
println!("lhs {:?}", lhs_l.shape());
|
println!("lhs {:?}", lhs_l.shape());
|
||||||
|
|
||||||
// Create kernel
|
|
||||||
let matrix_multiplication = MatrixMultiplication::init(
|
|
||||||
&self.device,
|
|
||||||
transpose_left,
|
|
||||||
transpose_right,
|
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
alpha,
|
|
||||||
beta,
|
|
||||||
)
|
|
||||||
.expect("Failed to create matrix multiplication kernel");
|
|
||||||
|
|
||||||
// Encode kernel to command buffer
|
|
||||||
matrix_multiplication.encode_to_command_buffer(
|
|
||||||
&self.device.command_buffer,
|
|
||||||
&left_matrix,
|
|
||||||
&right_matrix,
|
|
||||||
&result_matrix,
|
|
||||||
);
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: Arc::new(out_buffer),
|
buffer: Arc::new(out_buffer),
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
|
Reference in New Issue
Block a user