mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Adding CMP
This commit is contained in:
@ -549,8 +549,16 @@ impl BackendStorage for MetalStorage {
|
|||||||
Ok(Self::new(buffer, device, dtype))
|
Ok(Self::new(buffer, device, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||||
crate::bail!("cmp metal")
|
let name = match op {
|
||||||
|
CmpOp::Eq => "eq",
|
||||||
|
CmpOp::Ne => "ne",
|
||||||
|
CmpOp::Le => "le",
|
||||||
|
CmpOp::Ge => "ge",
|
||||||
|
CmpOp::Lt => "lt",
|
||||||
|
CmpOp::Gt => "gt",
|
||||||
|
};
|
||||||
|
self.binary(name, rhs, lhs_l, rhs_l)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||||
@ -711,76 +719,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
lhs_l: &Layout,
|
lhs_l: &Layout,
|
||||||
rhs_l: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let device = self.device();
|
self.binary(B::KERNEL, rhs, lhs_l, rhs_l)
|
||||||
let dtype = self.dtype;
|
|
||||||
let shape = lhs_l.shape();
|
|
||||||
let el_count = shape.elem_count();
|
|
||||||
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
|
|
||||||
let command_buffer = device.command_buffer()?;
|
|
||||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
|
||||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
|
||||||
&& &B::KERNEL[..1] != "b"
|
|
||||||
{
|
|
||||||
use candle_metal_kernels::binary::contiguous;
|
|
||||||
|
|
||||||
let kernel_name = match (B::KERNEL, dtype) {
|
|
||||||
("add", DType::F32) => contiguous::add::FLOAT,
|
|
||||||
("sub", DType::F32) => contiguous::sub::FLOAT,
|
|
||||||
("mul", DType::F32) => contiguous::mul::FLOAT,
|
|
||||||
("div", DType::F32) => contiguous::div::FLOAT,
|
|
||||||
("add", DType::F16) => contiguous::add::HALF,
|
|
||||||
("sub", DType::F16) => contiguous::sub::HALF,
|
|
||||||
("mul", DType::F16) => contiguous::mul::HALF,
|
|
||||||
("div", DType::F16) => contiguous::div::HALF,
|
|
||||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
|
||||||
};
|
|
||||||
candle_metal_kernels::call_binary_contiguous(
|
|
||||||
&device.device,
|
|
||||||
&command_buffer,
|
|
||||||
&device.kernels,
|
|
||||||
kernel_name,
|
|
||||||
el_count,
|
|
||||||
&self.buffer,
|
|
||||||
&rhs.buffer,
|
|
||||||
&buffer,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
} else {
|
|
||||||
use candle_metal_kernels::binary::strided;
|
|
||||||
|
|
||||||
let kernel_name = match (B::KERNEL, dtype) {
|
|
||||||
("badd", DType::F32) => strided::add::FLOAT,
|
|
||||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
|
||||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
|
||||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
|
||||||
("bminimum", DType::F32) => strided::min::FLOAT,
|
|
||||||
("bmaximum", DType::F32) => strided::max::FLOAT,
|
|
||||||
("badd", DType::F16) => strided::add::HALF,
|
|
||||||
("bsub", DType::F16) => strided::sub::HALF,
|
|
||||||
("bmul", DType::F16) => strided::mul::HALF,
|
|
||||||
("bdiv", DType::F16) => strided::div::HALF,
|
|
||||||
("bminimum", DType::F16) => strided::min::HALF,
|
|
||||||
("bmaximum", DType::F16) => strided::max::HALF,
|
|
||||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
|
||||||
};
|
|
||||||
candle_metal_kernels::call_binary_strided(
|
|
||||||
&device.device,
|
|
||||||
&command_buffer,
|
|
||||||
&device.kernels,
|
|
||||||
kernel_name,
|
|
||||||
lhs_l.dims(),
|
|
||||||
&self.buffer,
|
|
||||||
lhs_l.stride(),
|
|
||||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
|
||||||
&rhs.buffer,
|
|
||||||
rhs_l.stride(),
|
|
||||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
|
||||||
&buffer,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
}
|
|
||||||
command_buffer.set_label("binary");
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn where_cond(
|
fn where_cond(
|
||||||
@ -1043,6 +982,111 @@ impl MetalStorage {
|
|||||||
pub fn buffer(&self) -> &Buffer {
|
pub fn buffer(&self) -> &Buffer {
|
||||||
&self.buffer
|
&self.buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn binary(
|
||||||
|
&self,
|
||||||
|
op: &'static str,
|
||||||
|
rhs: &Self,
|
||||||
|
lhs_l: &Layout,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let device = self.device();
|
||||||
|
let shape = lhs_l.shape();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let command_buffer = device.command_buffer()?;
|
||||||
|
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||||
|
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||||
|
&& &op[..1] != "b"
|
||||||
|
{
|
||||||
|
use candle_metal_kernels::binary::contiguous;
|
||||||
|
|
||||||
|
let (kernel_name, dtype) = match (op, self.dtype) {
|
||||||
|
("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
|
||||||
|
("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype),
|
||||||
|
("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype),
|
||||||
|
("div", DType::F32) => (contiguous::div::FLOAT, self.dtype),
|
||||||
|
("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8),
|
||||||
|
("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8),
|
||||||
|
("le", DType::F32) => (contiguous::le::FLOAT, DType::U8),
|
||||||
|
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
|
||||||
|
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
|
||||||
|
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
|
||||||
|
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
|
||||||
|
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
|
||||||
|
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
|
||||||
|
("div", DType::F16) => (contiguous::div::HALF, self.dtype),
|
||||||
|
("eq", DType::F16) => (contiguous::eq::HALF, DType::U8),
|
||||||
|
("ne", DType::F16) => (contiguous::ne::HALF, DType::U8),
|
||||||
|
("le", DType::F16) => (contiguous::le::HALF, DType::U8),
|
||||||
|
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||||
|
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||||
|
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||||
|
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||||
|
};
|
||||||
|
let buffer = device.new_buffer(el_count, dtype, op)?;
|
||||||
|
candle_metal_kernels::call_binary_contiguous(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
&self.buffer,
|
||||||
|
&rhs.buffer,
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
(buffer, dtype)
|
||||||
|
} else {
|
||||||
|
use candle_metal_kernels::binary::strided;
|
||||||
|
|
||||||
|
let (kernel_name, dtype) = match (op, self.dtype) {
|
||||||
|
("badd", DType::F32) => (strided::add::FLOAT, self.dtype),
|
||||||
|
("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype),
|
||||||
|
("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype),
|
||||||
|
("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype),
|
||||||
|
("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype),
|
||||||
|
("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype),
|
||||||
|
("eq", DType::F32) => (strided::eq::FLOAT, DType::U8),
|
||||||
|
("ne", DType::F32) => (strided::ne::FLOAT, DType::U8),
|
||||||
|
("le", DType::F32) => (strided::le::FLOAT, DType::U8),
|
||||||
|
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
|
||||||
|
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
|
||||||
|
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
|
||||||
|
("badd", DType::F16) => (strided::add::HALF, self.dtype),
|
||||||
|
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
|
||||||
|
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
|
||||||
|
("bdiv", DType::F16) => (strided::div::HALF, self.dtype),
|
||||||
|
("bminimum", DType::F16) => (strided::min::HALF, self.dtype),
|
||||||
|
("bmaximum", DType::F16) => (strided::max::HALF, self.dtype),
|
||||||
|
("eq", DType::F16) => (strided::eq::HALF, DType::U8),
|
||||||
|
("ne", DType::F16) => (strided::ne::HALF, DType::U8),
|
||||||
|
("le", DType::F16) => (strided::le::HALF, DType::U8),
|
||||||
|
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||||
|
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||||
|
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||||
|
(name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"),
|
||||||
|
};
|
||||||
|
let buffer = device.new_buffer(el_count, dtype, op)?;
|
||||||
|
candle_metal_kernels::call_binary_strided(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
lhs_l.dims(),
|
||||||
|
&self.buffer,
|
||||||
|
lhs_l.stride(),
|
||||||
|
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||||
|
&rhs.buffer,
|
||||||
|
rhs_l.stride(),
|
||||||
|
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
(buffer, dtype)
|
||||||
|
};
|
||||||
|
command_buffer.set_label("binary");
|
||||||
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BackendDevice for MetalDevice {
|
impl BackendDevice for MetalDevice {
|
||||||
|
@ -25,15 +25,15 @@ kernel void FN_NAME( \
|
|||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
device const TYPENAME *left, \
|
device const TYPENAME *left, \
|
||||||
device const TYPENAME *right, \
|
device const TYPENAME *right, \
|
||||||
device TYPENAME *output, \
|
device OUT_TYPENAME *output, \
|
||||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (thread_position_in_grid >= dim) { \
|
if (tid >= dim) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
TYPENAME x = left[thread_position_in_grid]; \
|
TYPENAME x = left[tid]; \
|
||||||
TYPENAME y = right[thread_position_in_grid]; \
|
TYPENAME y = right[tid]; \
|
||||||
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
|
output[tid] = OUT_TYPENAME(FN); \
|
||||||
}\
|
}\
|
||||||
kernel void FN_NAME_STRIDED( \
|
kernel void FN_NAME_STRIDED( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
@ -43,15 +43,15 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
constant size_t *right_strides, \
|
constant size_t *right_strides, \
|
||||||
device const TYPENAME *left, \
|
device const TYPENAME *left, \
|
||||||
device const TYPENAME *right, \
|
device const TYPENAME *right, \
|
||||||
device TYPENAME *output, \
|
device OUT_TYPENAME *output, \
|
||||||
uint thread_position_in_grid [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (thread_position_in_grid >= dim) { \
|
if (tid >= dim) { \
|
||||||
return; \
|
return; \
|
||||||
} \
|
} \
|
||||||
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
|
TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
|
||||||
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
|
TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
|
||||||
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
|
output[tid] = OUT_TYPENAME(FN); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BINARY_OP(FN, NAME) \
|
#define BINARY_OP(FN, NAME) \
|
||||||
@ -61,6 +61,10 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
|
|||||||
#define BFLOAT_BINARY_OP(FN, NAME) \
|
#define BFLOAT_BINARY_OP(FN, NAME) \
|
||||||
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
|
||||||
|
|
||||||
|
#define BINARY_OP_OUT(NAME, FN) \
|
||||||
|
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
|
||||||
|
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
|
||||||
|
|
||||||
|
|
||||||
BINARY_OP(x + y, add)
|
BINARY_OP(x + y, add)
|
||||||
BINARY_OP(x - y, sub)
|
BINARY_OP(x - y, sub)
|
||||||
@ -69,6 +73,13 @@ BINARY_OP(x / y, div)
|
|||||||
BINARY_OP(MIN(x, y), min)
|
BINARY_OP(MIN(x, y), min)
|
||||||
BINARY_OP(MAX(x, y), max)
|
BINARY_OP(MAX(x, y), max)
|
||||||
|
|
||||||
|
BINARY_OP_OUT(eq, x == y)
|
||||||
|
BINARY_OP_OUT(ne, x != y)
|
||||||
|
BINARY_OP_OUT(le, x <= y)
|
||||||
|
BINARY_OP_OUT(lt, x < y)
|
||||||
|
BINARY_OP_OUT(ge, x >= y)
|
||||||
|
BINARY_OP_OUT(gt, x > y)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
BFLOAT_BINARY_OP(x + y, add)
|
BFLOAT_BINARY_OP(x + y, add)
|
||||||
BFLOAT_BINARY_OP(x - y, sub)
|
BFLOAT_BINARY_OP(x - y, sub)
|
||||||
|
@ -166,7 +166,7 @@ pub mod unary {
|
|||||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
|
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
ops!(add, sub, mul, div, min, max);
|
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
Reference in New Issue
Block a user