Adding CMP

This commit is contained in:
Nicolas Patry
2023-12-17 22:32:25 +01:00
parent 0a6e0a8c9a
commit e4b0cc59f5
3 changed files with 140 additions and 85 deletions

View File

@ -549,8 +549,16 @@ impl BackendStorage for MetalStorage {
Ok(Self::new(buffer, device, dtype))
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
crate::bail!("cmp metal")
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
let name = match op {
CmpOp::Eq => "eq",
CmpOp::Ne => "ne",
CmpOp::Le => "le",
CmpOp::Ge => "ge",
CmpOp::Lt => "lt",
CmpOp::Gt => "gt",
};
self.binary(name, rhs, lhs_l, rhs_l)
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
@ -711,76 +719,7 @@ impl BackendStorage for MetalStorage {
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
let command_buffer = device.command_buffer()?;
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
&& &B::KERNEL[..1] != "b"
{
use candle_metal_kernels::binary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("add", DType::F16) => contiguous::add::HALF,
("sub", DType::F16) => contiguous::sub::HALF,
("mul", DType::F16) => contiguous::mul::HALF,
("div", DType::F16) => contiguous::div::HALF,
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::binary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("badd", DType::F32) => strided::add::FLOAT,
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
("bminimum", DType::F32) => strided::min::FLOAT,
("bmaximum", DType::F32) => strided::max::FLOAT,
("badd", DType::F16) => strided::add::HALF,
("bsub", DType::F16) => strided::sub::HALF,
("bmul", DType::F16) => strided::mul::HALF,
("bdiv", DType::F16) => strided::div::HALF,
("bminimum", DType::F16) => strided::min::HALF,
("bmaximum", DType::F16) => strided::max::HALF,
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
};
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
}
command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), dtype))
self.binary(B::KERNEL, rhs, lhs_l, rhs_l)
}
fn where_cond(
@ -1043,6 +982,111 @@ impl MetalStorage {
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
pub fn binary(
&self,
op: &'static str,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
&& &op[..1] != "b"
{
use candle_metal_kernels::binary::contiguous;
let (kernel_name, dtype) = match (op, self.dtype) {
("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype),
("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype),
("div", DType::F32) => (contiguous::div::FLOAT, self.dtype),
("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8),
("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8),
("le", DType::F32) => (contiguous::le::FLOAT, DType::U8),
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
("div", DType::F16) => (contiguous::div::HALF, self.dtype),
("eq", DType::F16) => (contiguous::eq::HALF, DType::U8),
("ne", DType::F16) => (contiguous::ne::HALF, DType::U8),
("le", DType::F16) => (contiguous::le::HALF, DType::U8),
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
};
let buffer = device.new_buffer(el_count, dtype, op)?;
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
(buffer, dtype)
} else {
use candle_metal_kernels::binary::strided;
let (kernel_name, dtype) = match (op, self.dtype) {
("badd", DType::F32) => (strided::add::FLOAT, self.dtype),
("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype),
("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype),
("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype),
("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype),
("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype),
("eq", DType::F32) => (strided::eq::FLOAT, DType::U8),
("ne", DType::F32) => (strided::ne::FLOAT, DType::U8),
("le", DType::F32) => (strided::le::FLOAT, DType::U8),
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
("bdiv", DType::F16) => (strided::div::HALF, self.dtype),
("bminimum", DType::F16) => (strided::min::HALF, self.dtype),
("bmaximum", DType::F16) => (strided::max::HALF, self.dtype),
("eq", DType::F16) => (strided::eq::HALF, DType::U8),
("ne", DType::F16) => (strided::ne::HALF, DType::U8),
("le", DType::F16) => (strided::le::HALF, DType::U8),
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
(name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"),
};
let buffer = device.new_buffer(el_count, dtype, op)?;
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
(buffer, dtype)
};
command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), dtype))
}
}
impl BackendDevice for MetalDevice {