Rework the buffer offset logic for metal kernels (#2028)

* Move the metal kernels utils into a separate module.

* Use the BufferOffset for unary ops.

* Fix clippy lints.

* Use the new BufferOffset.

* Adapt the binary ops.

* Affine.

* More ops (powf, elu, cast).
This commit is contained in:
Laurent Mazare
2024-04-07 22:37:53 +02:00
committed by GitHub
parent 7f354473cf
commit c5fe4a7f89
4 changed files with 305 additions and 286 deletions

View File

@ -2,8 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels::CallConvTranspose2dCfg;
use candle_metal_kernels::Kernels;
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
@ -12,6 +11,12 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
/// Pairs `buffer` with the byte offset at which the data described by `l` begins.
///
/// The offset is the layout's start offset multiplied by the size in bytes of
/// `dtype`, so callers hand the kernels a single `BufferOffset` value instead of
/// passing the buffer and a separately computed offset.
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
    let offset_in_bytes = l.start_offset() * dtype.size_in_bytes();
    BufferOffset {
        buffer,
        offset_in_bytes,
    }
}
/// Simple way to catch lock error without
/// depending on T
#[derive(thiserror::Error, Debug)]
@ -102,7 +107,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "affine")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
@ -115,7 +121,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
mul as f32,
add as f32,
@ -134,9 +140,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
mul as f32,
add as f32,
@ -155,7 +160,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "powf")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
@ -168,7 +174,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
pow as f32,
)
@ -186,9 +192,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
pow as f32,
)
@ -206,7 +211,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "elu")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
@ -219,7 +225,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
alpha as f32,
)
@ -237,9 +243,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
alpha as f32,
)
@ -344,7 +349,8 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
let command_buffer = device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
@ -392,8 +398,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -420,9 +425,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -439,7 +443,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label(B::KERNEL);
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
@ -511,7 +516,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -556,17 +561,16 @@ impl BackendStorage for MetalStorage {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
};
let dst = BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
0,
dst,
)
.map_err(MetalError::from)?;
}
@ -1358,17 +1362,20 @@ impl BackendStorage for MetalStorage {
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, src_l, self.dtype);
let dst = BufferOffset {
buffer: &dst.buffer,
offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
&self.buffer,
src,
src_l.stride(),
src_l.start_offset() * self.dtype.size_in_bytes(),
&dst.buffer,
dst_offset * dst.dtype.size_in_bytes(),
dst,
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
@ -1402,10 +1409,9 @@ impl MetalStorage {
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
&& &op[..1] != "b"
{
let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
use candle_metal_kernels::binary::contiguous;
let (kernel_name, dtype) = match (op, self.dtype) {
@ -1486,8 +1492,8 @@ impl MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
lhs,
rhs,
&buffer,
)
.map_err(MetalError::from)?;
@ -1585,12 +1591,10 @@ impl MetalStorage {
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
lhs,
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
rhs,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;