mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Adding CMP
This commit is contained in:
@ -549,8 +549,16 @@ impl BackendStorage for MetalStorage {
|
||||
Ok(Self::new(buffer, device, dtype))
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
crate::bail!("cmp metal")
|
||||
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||
let name = match op {
|
||||
CmpOp::Eq => "eq",
|
||||
CmpOp::Ne => "ne",
|
||||
CmpOp::Le => "le",
|
||||
CmpOp::Ge => "ge",
|
||||
CmpOp::Lt => "lt",
|
||||
CmpOp::Gt => "gt",
|
||||
};
|
||||
self.binary(name, rhs, lhs_l, rhs_l)
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
@ -711,76 +719,7 @@ impl BackendStorage for MetalStorage {
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
&& &B::KERNEL[..1] != "b"
|
||||
{
|
||||
use candle_metal_kernels::binary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("add", DType::F32) => contiguous::add::FLOAT,
|
||||
("sub", DType::F32) => contiguous::sub::FLOAT,
|
||||
("mul", DType::F32) => contiguous::mul::FLOAT,
|
||||
("div", DType::F32) => contiguous::div::FLOAT,
|
||||
("add", DType::F16) => contiguous::add::HALF,
|
||||
("sub", DType::F16) => contiguous::sub::HALF,
|
||||
("mul", DType::F16) => contiguous::mul::HALF,
|
||||
("div", DType::F16) => contiguous::div::HALF,
|
||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
use candle_metal_kernels::binary::strided;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("badd", DType::F32) => strided::add::FLOAT,
|
||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||
("bminimum", DType::F32) => strided::min::FLOAT,
|
||||
("bmaximum", DType::F32) => strided::max::FLOAT,
|
||||
("badd", DType::F16) => strided::add::HALF,
|
||||
("bsub", DType::F16) => strided::sub::HALF,
|
||||
("bmul", DType::F16) => strided::mul::HALF,
|
||||
("bdiv", DType::F16) => strided::div::HALF,
|
||||
("bminimum", DType::F16) => strided::min::HALF,
|
||||
("bmaximum", DType::F16) => strided::max::HALF,
|
||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.set_label("binary");
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
self.binary(B::KERNEL, rhs, lhs_l, rhs_l)
|
||||
}
|
||||
|
||||
fn where_cond(
|
||||
@ -1043,6 +982,111 @@ impl MetalStorage {
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn binary(
|
||||
&self,
|
||||
op: &'static str,
|
||||
rhs: &Self,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
&& &op[..1] != "b"
|
||||
{
|
||||
use candle_metal_kernels::binary::contiguous;
|
||||
|
||||
let (kernel_name, dtype) = match (op, self.dtype) {
|
||||
("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
|
||||
("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype),
|
||||
("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype),
|
||||
("div", DType::F32) => (contiguous::div::FLOAT, self.dtype),
|
||||
("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8),
|
||||
("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8),
|
||||
("le", DType::F32) => (contiguous::le::FLOAT, DType::U8),
|
||||
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
|
||||
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
|
||||
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
|
||||
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
|
||||
("div", DType::F16) => (contiguous::div::HALF, self.dtype),
|
||||
("eq", DType::F16) => (contiguous::eq::HALF, DType::U8),
|
||||
("ne", DType::F16) => (contiguous::ne::HALF, DType::U8),
|
||||
("le", DType::F16) => (contiguous::le::HALF, DType::U8),
|
||||
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
let buffer = device.new_buffer(el_count, dtype, op)?;
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
(buffer, dtype)
|
||||
} else {
|
||||
use candle_metal_kernels::binary::strided;
|
||||
|
||||
let (kernel_name, dtype) = match (op, self.dtype) {
|
||||
("badd", DType::F32) => (strided::add::FLOAT, self.dtype),
|
||||
("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype),
|
||||
("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype),
|
||||
("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype),
|
||||
("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype),
|
||||
("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype),
|
||||
("eq", DType::F32) => (strided::eq::FLOAT, DType::U8),
|
||||
("ne", DType::F32) => (strided::ne::FLOAT, DType::U8),
|
||||
("le", DType::F32) => (strided::le::FLOAT, DType::U8),
|
||||
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
|
||||
("badd", DType::F16) => (strided::add::HALF, self.dtype),
|
||||
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
|
||||
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
|
||||
("bdiv", DType::F16) => (strided::div::HALF, self.dtype),
|
||||
("bminimum", DType::F16) => (strided::min::HALF, self.dtype),
|
||||
("bmaximum", DType::F16) => (strided::max::HALF, self.dtype),
|
||||
("eq", DType::F16) => (strided::eq::HALF, DType::U8),
|
||||
("ne", DType::F16) => (strided::ne::HALF, DType::U8),
|
||||
("le", DType::F16) => (strided::le::HALF, DType::U8),
|
||||
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||
(name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
let buffer = device.new_buffer(el_count, dtype, op)?;
|
||||
candle_metal_kernels::call_binary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
(buffer, dtype)
|
||||
};
|
||||
command_buffer.set_label("binary");
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
|
@ -25,15 +25,15 @@ kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
device const TYPENAME *left, \
|
||||
device const TYPENAME *right, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
device OUT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
TYPENAME x = left[thread_position_in_grid]; \
|
||||
TYPENAME y = right[thread_position_in_grid]; \
|
||||
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
|
||||
TYPENAME x = left[tid]; \
|
||||
TYPENAME y = right[tid]; \
|
||||
output[tid] = OUT_TYPENAME(FN); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -43,15 +43,15 @@ kernel void FN_NAME_STRIDED( \
|
||||
constant size_t *right_strides, \
|
||||
device const TYPENAME *left, \
|
||||
device const TYPENAME *right, \
|
||||
device TYPENAME *output, \
|
||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
||||
device OUT_TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
|
||||
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
|
||||
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
|
||||
TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
|
||||
TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
|
||||
output[tid] = OUT_TYPENAME(FN); \
|
||||
}
|
||||
|
||||
#define BINARY_OP(FN, NAME) \
|
||||
@ -61,6 +61,10 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
|
||||
// Instantiates BINARY for expression FN with `bfloat` passed as both type
// arguments, producing kernels named NAME##_bf16 and NAME##_bf16_strided.
// NOTE(review): the invocations visible in this file sit under
// `#if __METAL_VERSION__ >= 310` — confirm all uses are guarded the same way.
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||
|
||||
// Instantiates BINARY for expression FN with float and half inputs, passing
// `uint8_t` as the third BINARY argument (presumably the output element type
// — confirm against the BINARY definition). Produces NAME##_f32 / NAME##_f16
// kernels plus their _strided variants.
// NOTE: the argument order here is (NAME, FN), the reverse of
// BINARY_OP(FN, NAME) and BFLOAT_BINARY_OP(FN, NAME).
#define BINARY_OP_OUT(NAME, FN) \
|
||||
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
|
||||
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
|
||||
|
||||
|
||||
BINARY_OP(x + y, add)
|
||||
BINARY_OP(x - y, sub)
|
||||
@ -69,6 +73,13 @@ BINARY_OP(x / y, div)
|
||||
BINARY_OP(MIN(x, y), min)
|
||||
BINARY_OP(MAX(x, y), max)
|
||||
|
||||
BINARY_OP_OUT(eq, x == y)
|
||||
BINARY_OP_OUT(ne, x != y)
|
||||
BINARY_OP_OUT(le, x <= y)
|
||||
BINARY_OP_OUT(lt, x < y)
|
||||
BINARY_OP_OUT(ge, x >= y)
|
||||
BINARY_OP_OUT(gt, x > y)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
|
@ -166,7 +166,7 @@ pub mod unary {
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
|
||||
}
|
||||
pub mod binary {
|
||||
ops!(add, sub, mul, div, min, max);
|
||||
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
|
Reference in New Issue
Block a user