mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Adding CMP
This commit is contained in:
@ -549,8 +549,16 @@ impl BackendStorage for MetalStorage {
|
||||
Ok(Self::new(buffer, device, dtype))
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
crate::bail!("cmp metal")
|
||||
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||
let name = match op {
|
||||
CmpOp::Eq => "eq",
|
||||
CmpOp::Ne => "ne",
|
||||
CmpOp::Le => "le",
|
||||
CmpOp::Ge => "ge",
|
||||
CmpOp::Lt => "lt",
|
||||
CmpOp::Gt => "gt",
|
||||
};
|
||||
self.binary(name, rhs, lhs_l, rhs_l)
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
@ -711,76 +719,7 @@ impl BackendStorage for MetalStorage {
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
&& &B::KERNEL[..1] != "b"
|
||||
{
|
||||
use candle_metal_kernels::binary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("add", DType::F32) => contiguous::add::FLOAT,
|
||||
("sub", DType::F32) => contiguous::sub::FLOAT,
|
||||
("mul", DType::F32) => contiguous::mul::FLOAT,
|
||||
("div", DType::F32) => contiguous::div::FLOAT,
|
||||
("add", DType::F16) => contiguous::add::HALF,
|
||||
("sub", DType::F16) => contiguous::sub::HALF,
|
||||
("mul", DType::F16) => contiguous::mul::HALF,
|
||||
("div", DType::F16) => contiguous::div::HALF,
|
||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
use candle_metal_kernels::binary::strided;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("badd", DType::F32) => strided::add::FLOAT,
|
||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||
("bminimum", DType::F32) => strided::min::FLOAT,
|
||||
("bmaximum", DType::F32) => strided::max::FLOAT,
|
||||
("badd", DType::F16) => strided::add::HALF,
|
||||
("bsub", DType::F16) => strided::sub::HALF,
|
||||
("bmul", DType::F16) => strided::mul::HALF,
|
||||
("bdiv", DType::F16) => strided::div::HALF,
|
||||
("bminimum", DType::F16) => strided::min::HALF,
|
||||
("bmaximum", DType::F16) => strided::max::HALF,
|
||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.set_label("binary");
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
self.binary(B::KERNEL, rhs, lhs_l, rhs_l)
|
||||
}
|
||||
|
||||
fn where_cond(
|
||||
@ -1043,6 +982,111 @@ impl MetalStorage {
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn binary(
|
||||
&self,
|
||||
op: &'static str,
|
||||
rhs: &Self,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let device = self.device();
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
&& &op[..1] != "b"
|
||||
{
|
||||
use candle_metal_kernels::binary::contiguous;
|
||||
|
||||
let (kernel_name, dtype) = match (op, self.dtype) {
|
||||
("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
|
||||
("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype),
|
||||
("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype),
|
||||
("div", DType::F32) => (contiguous::div::FLOAT, self.dtype),
|
||||
("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8),
|
||||
("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8),
|
||||
("le", DType::F32) => (contiguous::le::FLOAT, DType::U8),
|
||||
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
|
||||
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
|
||||
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
|
||||
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
|
||||
("div", DType::F16) => (contiguous::div::HALF, self.dtype),
|
||||
("eq", DType::F16) => (contiguous::eq::HALF, DType::U8),
|
||||
("ne", DType::F16) => (contiguous::ne::HALF, DType::U8),
|
||||
("le", DType::F16) => (contiguous::le::HALF, DType::U8),
|
||||
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
|
||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
let buffer = device.new_buffer(el_count, dtype, op)?;
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
(buffer, dtype)
|
||||
} else {
|
||||
use candle_metal_kernels::binary::strided;
|
||||
|
||||
let (kernel_name, dtype) = match (op, self.dtype) {
|
||||
("badd", DType::F32) => (strided::add::FLOAT, self.dtype),
|
||||
("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype),
|
||||
("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype),
|
||||
("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype),
|
||||
("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype),
|
||||
("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype),
|
||||
("eq", DType::F32) => (strided::eq::FLOAT, DType::U8),
|
||||
("ne", DType::F32) => (strided::ne::FLOAT, DType::U8),
|
||||
("le", DType::F32) => (strided::le::FLOAT, DType::U8),
|
||||
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
|
||||
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
|
||||
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
|
||||
("badd", DType::F16) => (strided::add::HALF, self.dtype),
|
||||
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
|
||||
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
|
||||
("bdiv", DType::F16) => (strided::div::HALF, self.dtype),
|
||||
("bminimum", DType::F16) => (strided::min::HALF, self.dtype),
|
||||
("bmaximum", DType::F16) => (strided::max::HALF, self.dtype),
|
||||
("eq", DType::F16) => (strided::eq::HALF, DType::U8),
|
||||
("ne", DType::F16) => (strided::ne::HALF, DType::U8),
|
||||
("le", DType::F16) => (strided::le::HALF, DType::U8),
|
||||
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
|
||||
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
|
||||
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
|
||||
(name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
let buffer = device.new_buffer(el_count, dtype, op)?;
|
||||
candle_metal_kernels::call_binary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
(buffer, dtype)
|
||||
};
|
||||
command_buffer.set_label("binary");
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
|
Reference in New Issue
Block a user