From e4b0cc59f5651bb4370598a902e43cd8b0af5976 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 17 Dec 2023 22:32:25 +0100 Subject: [PATCH] Adding CMP --- candle-core/src/metal_backend.rs | 188 ++++++++++++++++---------- candle-metal-kernels/src/binary.metal | 35 +++-- candle-metal-kernels/src/lib.rs | 2 +- 3 files changed, 140 insertions(+), 85 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 047313d1..6f82b0cc 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -549,8 +549,16 @@ impl BackendStorage for MetalStorage { Ok(Self::new(buffer, device, dtype)) } - fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { - crate::bail!("cmp metal") + fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result { + let name = match op { + CmpOp::Eq => "eq", + CmpOp::Ne => "ne", + CmpOp::Le => "le", + CmpOp::Ge => "ge", + CmpOp::Lt => "lt", + CmpOp::Gt => "gt", + }; + self.binary(name, rhs, lhs_l, rhs_l) } fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { @@ -711,76 +719,7 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - let device = self.device(); - let dtype = self.dtype; - let shape = lhs_l.shape(); - let el_count = shape.elem_count(); - let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; - let command_buffer = device.command_buffer()?; - if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) - && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) - && &B::KERNEL[..1] != "b" - { - use candle_metal_kernels::binary::contiguous; - - let kernel_name = match (B::KERNEL, dtype) { - ("add", DType::F32) => contiguous::add::FLOAT, - ("sub", DType::F32) => contiguous::sub::FLOAT, - ("mul", DType::F32) => contiguous::mul::FLOAT, - ("div", DType::F32) => contiguous::div::FLOAT, - ("add", DType::F16) => contiguous::add::HALF, - ("sub", DType::F16) => contiguous::sub::HALF, - ("mul", DType::F16) => contiguous::mul::HALF, - ("div", DType::F16) => contiguous::div::HALF, - (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), - }; - candle_metal_kernels::call_binary_contiguous( - &device.device, - &command_buffer, - &device.kernels, - kernel_name, - el_count, - &self.buffer, - &rhs.buffer, - &buffer, - ) - .map_err(MetalError::from)?; - } else { - use candle_metal_kernels::binary::strided; - - let kernel_name = match (B::KERNEL, dtype) { - ("badd", DType::F32) => strided::add::FLOAT, - ("bsub", DType::F32) => strided::sub::FLOAT, - ("bmul", DType::F32) => strided::mul::FLOAT, - ("bdiv", DType::F32) => strided::div::FLOAT, - ("bminimum", DType::F32) => strided::min::FLOAT, - ("bmaximum", DType::F32) => strided::max::FLOAT, - ("badd", DType::F16) => strided::add::HALF, - ("bsub", DType::F16) => strided::sub::HALF, - ("bmul", DType::F16) => strided::mul::HALF, - ("bdiv", DType::F16) => strided::div::HALF, - ("bminimum", DType::F16) => strided::min::HALF, - ("bmaximum", DType::F16) => strided::max::HALF, - (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), - }; - candle_metal_kernels::call_binary_strided( - &device.device, - &command_buffer, - &device.kernels, - kernel_name, - lhs_l.dims(), - &self.buffer, - lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &rhs.buffer, - rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &buffer, - ) - .map_err(MetalError::from)?; - } - command_buffer.set_label("binary"); - Ok(Self::new(buffer, device.clone(), dtype)) + self.binary(B::KERNEL, rhs, lhs_l, rhs_l) } fn where_cond( @@ -1043,6 +982,111 @@ impl MetalStorage { pub fn buffer(&self) -> &Buffer { &self.buffer } + + pub fn binary( + &self, + op: &'static str, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let device = self.device(); + let shape = lhs_l.shape(); + let el_count = shape.elem_count(); + let command_buffer = device.command_buffer()?; + let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) + && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + && &op[..1] != "b" + { + use candle_metal_kernels::binary::contiguous; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("add", DType::F32) => (contiguous::add::FLOAT, self.dtype), + ("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype), + ("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype), + ("div", DType::F32) => (contiguous::div::FLOAT, self.dtype), + ("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8), + ("le", DType::F32) => (contiguous::le::FLOAT, DType::U8), + ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), + ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), + ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), + ("div", DType::F16) => (contiguous::div::HALF, self.dtype), + ("eq", DType::F16) => (contiguous::eq::HALF, DType::U8), + ("ne", DType::F16) => (contiguous::ne::HALF, DType::U8), + ("le", DType::F16) => (contiguous::le::HALF, DType::U8), + ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), + ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), + ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"), + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + candle_metal_kernels::call_binary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + } else { + use candle_metal_kernels::binary::strided; + + let (kernel_name, dtype) = match (op, self.dtype) { + ("badd", DType::F32) => (strided::add::FLOAT, self.dtype), + ("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype), + ("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype), + ("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype), + ("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype), + ("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype), + ("eq", DType::F32) => (strided::eq::FLOAT, DType::U8), + ("ne", DType::F32) => (strided::ne::FLOAT, DType::U8), + ("le", DType::F32) => (strided::le::FLOAT, DType::U8), + ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), + ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), + ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), + ("badd", DType::F16) => (strided::add::HALF, self.dtype), + ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), + ("bmul", DType::F16) => (strided::mul::HALF, self.dtype), + ("bdiv", DType::F16) => (strided::div::HALF, self.dtype), + ("bminimum", DType::F16) => (strided::min::HALF, self.dtype), + ("bmaximum", DType::F16) => (strided::max::HALF, self.dtype), + ("eq", DType::F16) => (strided::eq::HALF, DType::U8), + ("ne", DType::F16) => (strided::ne::HALF, DType::U8), + ("le", DType::F16) => (strided::le::HALF, DType::U8), + ("lt", DType::F16) => (strided::lt::HALF, DType::U8), + ("ge", DType::F16) => (strided::ge::HALF, DType::U8), + ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + (name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"), + }; + let buffer = device.new_buffer(el_count, dtype, op)?; + candle_metal_kernels::call_binary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + lhs_l.dims(), + &self.buffer, + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &rhs.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &buffer, + ) + .map_err(MetalError::from)?; + (buffer, dtype) + }; + command_buffer.set_label("binary"); + Ok(Self::new(buffer, device.clone(), dtype)) + } } impl BackendDevice for MetalDevice { diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index f13589c1..8c3b4a8c 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -25,15 +25,15 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const TYPENAME *left, \ device const TYPENAME *right, \ - device TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + device OUT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - TYPENAME x = left[thread_position_in_grid]; \ - TYPENAME y = right[thread_position_in_grid]; \ - output[thread_position_in_grid] = OUT_TYPENAME(FN); \ + TYPENAME x = left[tid]; \ + TYPENAME y = right[tid]; \ + output[tid] = OUT_TYPENAME(FN); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -43,15 +43,15 @@ kernel void FN_NAME_STRIDED( \ constant size_t *right_strides, \ device const TYPENAME *left, \ device const TYPENAME *right, \ - device TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + device OUT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ - TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \ - output[thread_position_in_grid] = OUT_TYPENAME(FN); \ + TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \ + output[tid] = OUT_TYPENAME(FN); \ } #define BINARY_OP(FN, NAME) \ @@ -61,6 +61,10 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); #define BFLOAT_BINARY_OP(FN, NAME) \ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); +#define BINARY_OP_OUT(NAME, FN) \ +BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ +BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); + BINARY_OP(x + y, add) BINARY_OP(x - y, sub) @@ -69,6 +73,13 @@ BINARY_OP(x / y, div) BINARY_OP(MIN(x, y), min) BINARY_OP(MAX(x, y), max) +BINARY_OP_OUT(eq, x == y) +BINARY_OP_OUT(ne, x != y) +BINARY_OP_OUT(le, x <= y) +BINARY_OP_OUT(lt, x < y) +BINARY_OP_OUT(ge, x >= y) +BINARY_OP_OUT(gt, x > y) + #if __METAL_VERSION__ >= 310 BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index c34e34fe..7485ba72 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -166,7 +166,7 @@ pub mod unary { ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); } pub mod binary { - ops!(add, sub, mul, div, min, max); + ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); } #[derive(thiserror::Error, Debug)]