Adding CMP

This commit is contained in:
Nicolas Patry
2023-12-17 22:32:25 +01:00
parent 0a6e0a8c9a
commit e4b0cc59f5
3 changed files with 140 additions and 85 deletions

View File

@ -549,8 +549,16 @@ impl BackendStorage for MetalStorage {
Ok(Self::new(buffer, device, dtype))
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
crate::bail!("cmp metal")
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
let name = match op {
CmpOp::Eq => "eq",
CmpOp::Ne => "ne",
CmpOp::Le => "le",
CmpOp::Ge => "ge",
CmpOp::Lt => "lt",
CmpOp::Gt => "gt",
};
self.binary(name, rhs, lhs_l, rhs_l)
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
@ -711,76 +719,7 @@ impl BackendStorage for MetalStorage {
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
let command_buffer = device.command_buffer()?;
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
&& &B::KERNEL[..1] != "b"
{
use candle_metal_kernels::binary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("add", DType::F16) => contiguous::add::HALF,
("sub", DType::F16) => contiguous::sub::HALF,
("mul", DType::F16) => contiguous::mul::HALF,
("div", DType::F16) => contiguous::div::HALF,
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::binary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("badd", DType::F32) => strided::add::FLOAT,
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
("bminimum", DType::F32) => strided::min::FLOAT,
("bmaximum", DType::F32) => strided::max::FLOAT,
("badd", DType::F16) => strided::add::HALF,
("bsub", DType::F16) => strided::sub::HALF,
("bmul", DType::F16) => strided::mul::HALF,
("bdiv", DType::F16) => strided::div::HALF,
("bminimum", DType::F16) => strided::min::HALF,
("bmaximum", DType::F16) => strided::max::HALF,
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
};
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
}
command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), dtype))
self.binary(B::KERNEL, rhs, lhs_l, rhs_l)
}
fn where_cond(
@ -1043,6 +982,111 @@ impl MetalStorage {
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
pub fn binary(
&self,
op: &'static str,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
&& &op[..1] != "b"
{
use candle_metal_kernels::binary::contiguous;
let (kernel_name, dtype) = match (op, self.dtype) {
("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype),
("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype),
("div", DType::F32) => (contiguous::div::FLOAT, self.dtype),
("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8),
("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8),
("le", DType::F32) => (contiguous::le::FLOAT, DType::U8),
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
("div", DType::F16) => (contiguous::div::HALF, self.dtype),
("eq", DType::F16) => (contiguous::eq::HALF, DType::U8),
("ne", DType::F16) => (contiguous::ne::HALF, DType::U8),
("le", DType::F16) => (contiguous::le::HALF, DType::U8),
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
};
let buffer = device.new_buffer(el_count, dtype, op)?;
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
(buffer, dtype)
} else {
use candle_metal_kernels::binary::strided;
let (kernel_name, dtype) = match (op, self.dtype) {
("badd", DType::F32) => (strided::add::FLOAT, self.dtype),
("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype),
("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype),
("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype),
("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype),
("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype),
("eq", DType::F32) => (strided::eq::FLOAT, DType::U8),
("ne", DType::F32) => (strided::ne::FLOAT, DType::U8),
("le", DType::F32) => (strided::le::FLOAT, DType::U8),
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
("bdiv", DType::F16) => (strided::div::HALF, self.dtype),
("bminimum", DType::F16) => (strided::min::HALF, self.dtype),
("bmaximum", DType::F16) => (strided::max::HALF, self.dtype),
("eq", DType::F16) => (strided::eq::HALF, DType::U8),
("ne", DType::F16) => (strided::ne::HALF, DType::U8),
("le", DType::F16) => (strided::le::HALF, DType::U8),
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
(name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"),
};
let buffer = device.new_buffer(el_count, dtype, op)?;
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
(buffer, dtype)
};
command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), dtype))
}
}
impl BackendDevice for MetalDevice {

View File

@ -25,15 +25,15 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
device OUT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
if (tid >= dim) { \
return; \
} \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
TYPENAME x = left[tid]; \
TYPENAME y = right[tid]; \
output[tid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -43,15 +43,15 @@ kernel void FN_NAME_STRIDED( \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
device OUT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (thread_position_in_grid >= dim) { \
if (tid >= dim) { \
return; \
} \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
output[tid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \
@ -61,6 +61,10 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
@ -69,6 +73,13 @@ BINARY_OP(x / y, div)
BINARY_OP(MIN(x, y), min)
BINARY_OP(MAX(x, y), max)
BINARY_OP_OUT(eq, x == y)
BINARY_OP_OUT(ne, x != y)
BINARY_OP_OUT(le, x <= y)
BINARY_OP_OUT(lt, x < y)
BINARY_OP_OUT(ge, x >= y)
BINARY_OP_OUT(gt, x > y)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)

View File

@ -166,7 +166,7 @@ pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
ops!(add, sub, mul, div, min, max);
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
}
#[derive(thiserror::Error, Debug)]