mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils into a separate module.
* Use the BufferOffset for unary ops.
* Fix clippy lints.
* Use the new BufferOffset.
* Adapt the binary ops.
* Affine.
* More ops (powf, elu, cast).
This commit is contained in:
@ -2,8 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels::CallConvTranspose2dCfg;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
@ -12,6 +11,12 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
mod device;
|
||||
pub use device::{DeviceId, MetalDevice};
|
||||
|
||||
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
|
||||
BufferOffset {
|
||||
buffer,
|
||||
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
|
||||
}
|
||||
}
|
||||
/// Simple way to catch lock error without
|
||||
/// depending on T
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -102,7 +107,8 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
let buffer = device.new_buffer(el, self.dtype, "affine")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let src = buffer_o(&self.buffer, layout, dtype);
|
||||
if layout.is_contiguous() {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "affine_f32",
|
||||
DType::F16 => "affine_f16",
|
||||
@ -115,7 +121,7 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
name,
|
||||
el,
|
||||
&self.buffer,
|
||||
src,
|
||||
&buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
@ -134,9 +140,8 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
src,
|
||||
layout.stride(),
|
||||
layout.start_offset() * dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
@ -155,7 +160,8 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
let buffer = device.new_buffer(el, self.dtype, "powf")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let src = buffer_o(&self.buffer, layout, dtype);
|
||||
if layout.is_contiguous() {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_f32",
|
||||
DType::F16 => "powf_f16",
|
||||
@ -168,7 +174,7 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
name,
|
||||
el,
|
||||
&self.buffer,
|
||||
src,
|
||||
&buffer,
|
||||
pow as f32,
|
||||
)
|
||||
@ -186,9 +192,8 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
src,
|
||||
layout.stride(),
|
||||
layout.start_offset() * dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
pow as f32,
|
||||
)
|
||||
@ -206,7 +211,8 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
let buffer = device.new_buffer(el, self.dtype, "elu")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
if layout.is_contiguous() {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_f32",
|
||||
DType::F16 => "elu_f16",
|
||||
@ -219,7 +225,7 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
name,
|
||||
el,
|
||||
&self.buffer,
|
||||
src,
|
||||
&buffer,
|
||||
alpha as f32,
|
||||
)
|
||||
@ -237,9 +243,8 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
src,
|
||||
layout.stride(),
|
||||
layout.start_offset() * dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
alpha as f32,
|
||||
)
|
||||
@ -344,7 +349,8 @@ impl BackendStorage for MetalStorage {
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
if layout.is_contiguous() {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
(DType::U32, DType::F16) => "cast_u32_f16",
|
||||
@ -392,8 +398,7 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -420,9 +425,8 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
src,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -439,7 +443,8 @@ impl BackendStorage for MetalStorage {
|
||||
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
command_buffer.set_label(B::KERNEL);
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
if layout.is_contiguous() {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
@ -511,7 +516,7 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -556,17 +561,16 @@ impl BackendStorage for MetalStorage {
|
||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
let dst = BufferOffset::zero_offset(&buffer);
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
src,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
0,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
@ -1358,17 +1362,20 @@ impl BackendStorage for MetalStorage {
|
||||
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
|
||||
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
|
||||
};
|
||||
let src = buffer_o(&self.buffer, src_l, self.dtype);
|
||||
let dst = BufferOffset {
|
||||
buffer: &dst.buffer,
|
||||
offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
|
||||
};
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
kernel_name,
|
||||
src_l.dims(),
|
||||
&self.buffer,
|
||||
src,
|
||||
src_l.stride(),
|
||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&dst.buffer,
|
||||
dst_offset * dst.dtype.size_in_bytes(),
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.set_label("copy_strided");
|
||||
@ -1402,10 +1409,9 @@ impl MetalStorage {
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
&& &op[..1] != "b"
|
||||
{
|
||||
let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
|
||||
let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
|
||||
let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
|
||||
use candle_metal_kernels::binary::contiguous;
|
||||
|
||||
let (kernel_name, dtype) = match (op, self.dtype) {
|
||||
@ -1486,8 +1492,8 @@ impl MetalStorage {
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
lhs,
|
||||
rhs,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1585,12 +1591,10 @@ impl MetalStorage {
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
lhs,
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
rhs,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
Reference in New Issue
Block a user