From 6d4c8c07074d01934d0e0bdf1932f6a9c48386ef Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 6 Nov 2023 03:27:22 +0100 Subject: [PATCH] Use metal encode_gemm --- candle-core/src/metal_backend.rs | 68 ++++++++++---------------------- 1 file changed, 21 insertions(+), 47 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 1850cc8f..377e1406 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1,15 +1,15 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::bail; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D}; +use crate::error::Error; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use core::mem; use half::{bf16, f16}; use metal; -use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; -use metal::mps::{Float32, MPSDataType}; -use metal::{Buffer, MTLResourceOptions}; +use metal::mps::matrix::encode_gemm; +use metal::mps::Float32; +use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::sync::Arc; /// Metal related errors @@ -288,7 +288,7 @@ impl MetalStorage { let elem_count = b * m * n; match (self.dtype, rhs.dtype) { (DType::F32, DType::F32) => { - let out_buffer = self.device.new_buffer(elem_count, self.dtype); + let mut out_buffer = self.device.new_buffer(elem_count, self.dtype); if b != 1 { println!("TODO implement batched matmul for B={b}"); // bail!("Didn't implemented strided matmul yet"); @@ -310,54 +310,28 @@ impl MetalStorage { dtype: self.dtype(), }); } - let m: u64 = m.try_into().expect("usize should fit u64"); - let n: u64 = n.try_into().expect("usize should fit u64"); - let k: u64 = k.try_into().expect("usize should fit u64"); - // Create descriptors - let left_descriptor = - MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID); - let right_descriptor = - MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID); - let result_descriptor = - MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID); + + encode_gemm::( + &self.device, + &self.device.command_buffer, + transpose_left, + transpose_right, + &self.buffer, + &rhs.buffer, + &mut out_buffer, + m as NSUInteger, + n as NSUInteger, + k as NSUInteger, + alpha, + beta, + ) + .map_err(|e| Error::Metal(e))?; println!("lhs {:?} {m} {k}", self.buffer.length()); println!("rhs {:?} {k} {n}", rhs.buffer.length()); println!("out {:?} {m} {n}", out_buffer.length()); - // Create matrix objects - let left_matrix = - Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor) - .expect("Failed to create left matrix"); - let right_matrix = - Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor) - .expect("Failed to create left matrix"); - - let result_matrix = - Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor) - .expect("Failed to create left matrix"); - println!("lhs {:?}", lhs_l.shape()); - // Create kernel - let matrix_multiplication = MatrixMultiplication::init( - &self.device, - transpose_left, - transpose_right, - m, - n, - k, - alpha, - beta, - ) - .expect("Failed to create matrix multiplication kernel"); - - // Encode kernel to command buffer - matrix_multiplication.encode_to_command_buffer( - &self.device.command_buffer, - &left_matrix, - &right_matrix, - &result_matrix, - ); Ok(Self { buffer: Arc::new(out_buffer), device: self.device.clone(),