From e9c052bf94521b418852a1c5231c12ddce99a78f Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 20 Jul 2023 10:40:31 +0200
Subject: [PATCH] Add the comparison operations. (#207)

* Add the comparison operations.

* Add the helper functions on the tensor side.

* More cmp operations.

* Cpu implementation for the comparison operations.
---
 candle-core/src/backend.rs            |  5 ++-
 candle-core/src/backprop.rs           | 15 ++++---
 candle-core/src/cpu_backend.rs        | 61 +++++++++++++++++++++++++--
 candle-core/src/cuda_backend.rs       | 20 ++++-----
 candle-core/src/dummy_cuda_backend.rs |  7 ++-
 candle-core/src/op.rs                 | 29 ++++++++-----
 candle-core/src/storage.rs            | 37 +++++++++++++---
 candle-core/src/tensor.rs             | 45 ++++++++++++++++++--
 8 files changed, 178 insertions(+), 41 deletions(-)

diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 018279b3..307868dd 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -1,3 +1,4 @@
+use crate::op::{CmpOp, ReduceOp};
 use crate::{CpuStorage, DType, Layout, Result, Shape};

 pub(crate) trait BackendStorage: Sized {
@@ -16,7 +17,9 @@ pub(crate) trait BackendStorage: Sized {

     fn elu(&self, _: &Layout, _: f64) -> Result<Self>;

-    fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
+    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
+
+    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;

     fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 3de11d35..4d968e7f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -1,4 +1,5 @@
-use crate::{op::Op, Error, Result, Tensor, TensorId};
+use crate::op::{Op, ReduceOp};
+use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;

 impl Tensor {
@@ -66,9 +67,8 @@ impl Tensor {
                 }
                 Op::Reshape(node)
                 | Op::Broadcast(node)
-                | Op::Sum(node, _)
-                | Op::Max(node, _)
-                | Op::Min(node, _)
+                | Op::Cmp(node, _)
+                | Op::Reduce(node, _, _)
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
@@ -201,14 +201,15 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.broadcast_add(&arg_grad)?
                     }
-                    Op::Sum(arg, _sum_dims) => {
+                    Op::Reduce(arg, ReduceOp::Sum, _) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.broadcast_add(&grad)?
                     }
-                    Op::Max(_args, _sum_dims) => {
+                    Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
+                    Op::Reduce(_args, ReduceOp::Max, _) => {
                         return Err(Error::BackwardNotSupported { op: "max" })
                     }
-                    Op::Min(_args, _sum_dims) => {
+                    Op::Reduce(_args, ReduceOp::Min, _) => {
                         return Err(Error::BackwardNotSupported { op: "min" })
                     }
                     Op::ToDType(arg) => {
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 925ca112..b12e0702 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, ReduceOp, UnaryOp};
+use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp};
 use crate::{DType, Error, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};

@@ -62,6 +62,57 @@ trait Map2 {
     }
 }

+trait Map2U8 {
+    const OP: &'static str;
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
+
+    fn map(
+        &self,
+        v1: &CpuStorage,
+        l1: &Layout,
+        v2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<CpuStorage> {
+        match (v1, v2) {
+            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }
+            .bt()),
+        }
+    }
+}
+
+struct Cmp(CmpOp);
+impl Map2U8 for Cmp {
+    const OP: &'static str = "cmp";
+    #[inline(always)]
+    fn f<T: WithDType>(
+        &self,
+        lhs: &[T],
+        lhs_l: &Layout,
+        rhs: &[T],
+        rhs_l: &Layout,
+    ) -> Result<Vec<u8>> {
+        let dst = match self.0 {
+            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
+            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
+            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
+            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
+            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
+            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
+        };
+        Ok(dst)
+    }
+}
+
 struct WCond<'a>(&'a [u32], &'a Layout);

 impl<'a> Map2 for WCond<'a> {
@@ -269,13 +320,13 @@ fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
 }

 // This function maps over two strided index sequences.
-fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
+fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
     lhs_l: &Layout,
     rhs_l: &Layout,
     lhs: &[T],
     rhs: &[T],
     mut f: F,
-) -> Vec<T> {
+) -> Vec<U> {
     match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
         (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
             .iter()
@@ -1064,6 +1115,10 @@ impl BackendStorage for CpuStorage {
         .map(self, layout)
     }

+    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+        Cmp(op).map(self, lhs_l, rhs, rhs_l)
+    }
+
     fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         // [self] stores data in a contiguous way starting at offset 0.
         match self {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index b74137f3..9e47c133 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,4 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
+use crate::op::{CmpOp, ReduceOp};
 use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -515,7 +516,7 @@ impl<'a> Map1 for Sum<'a> {
     }
 }

-struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
+struct FastReduce<'a>(&'a [usize], ReduceOp);
 impl<'a> Map1 for FastReduce<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
@@ -558,9 +559,9 @@ impl<'a> Map1 for FastReduce<'a> {
             .w()?;
         let src = &src.slice(layout.start_offset()..);
         let name = match self.1 {
-            crate::op::ReduceOp::Sum => "fast_sum",
-            crate::op::ReduceOp::Min => "fast_min",
-            crate::op::ReduceOp::Max => "fast_max",
+            ReduceOp::Sum => "fast_sum",
+            ReduceOp::Min => "fast_min",
+            ReduceOp::Max => "fast_max",
         };
         let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
         // SAFETY: filled in by the follow up kernel.
@@ -961,17 +962,16 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }

-    fn reduce_op(
-        &self,
-        op: crate::op::ReduceOp,
-        layout: &Layout,
-        sum_dims: &[usize],
-    ) -> Result<Self> {
+    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         let device = self.device().clone();
         let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }

+    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement cmp").into())
+    }
+
     fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
         Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
     }
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index f7cf8ab8..942e82ed 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -1,4 +1,5 @@
 #![allow(dead_code)]
+use crate::op::{CmpOp, ReduceOp};
 use crate::{CpuStorage, DType, Error, Layout, Result, Shape};

 #[derive(Debug, Clone)]
@@ -40,7 +41,11 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

-    fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
+    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }

diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index c5ff8179..ece6969c 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -2,12 +2,31 @@ use crate::Tensor;
 use half::{bf16, f16};
 use num_traits::float::Float;

+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum CmpOp {
+    Eq,
+    Ne,
+    Le,
+    Ge,
+    Lt,
+    Gt,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReduceOp {
+    Sum,
+    Min,
+    Max,
+}
+
 #[derive(Clone)]
 pub(crate) enum Op {
     Add(Tensor, Tensor),
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
+    Cmp(Tensor, CmpOp),
+    Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),
@@ -28,9 +47,6 @@ pub(crate) enum Op {
         mul: f64,
         add: f64,
     },
-    Sum(Tensor, Vec<usize>),
-    Max(Tensor, Vec<usize>),
-    Min(Tensor, Vec<usize>),
     ToDType(Tensor),
     Broadcast(Tensor),
     Exp(Tensor),
@@ -356,10 +372,3 @@ impl UnaryOp for Relu {
         v
     }
 }
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum ReduceOp {
-    Sum,
-    Min,
-    Max,
-}
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index e689905e..fb72322c 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,5 +1,6 @@
 use crate::backend::BackendStorage;
-use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
+use crate::op::{self, CmpOp, ReduceOp};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};

 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.
@@ -80,12 +81,38 @@ impl Storage {
         }
     }

-    pub(crate) fn reduce_op(
+    pub(crate) fn cmp(
         &self,
-        op: crate::op::ReduceOp,
-        layout: &Layout,
-        s: &[usize],
+        op: CmpOp,
+        rhs: &Self,
+        lhs_layout: &Layout,
+        rhs_layout: &Layout,
     ) -> Result<Self> {
+        self.same_device(rhs, "cmp")?;
+        self.same_dtype(rhs, "cmp")?;
+        match (self, rhs) {
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Cuda(storage))
+            }
+            (lhs, rhs) => {
+                // Should not happen because of the same device check above but we're defensive
+                // anyway.
+                Err(Error::DeviceMismatchBinaryOp {
+                    lhs: lhs.device().location(),
+                    rhs: rhs.device().location(),
+                    op: "cmp",
+                }
+                .bt())
+            }
+        }
+    }
+
+    pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
                 let storage = storage.reduce_op(op, layout, s)?;
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 276a522e..d6c3e9cb 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{Op, ReduceOp};
+use crate::op::{CmpOp, Op, ReduceOp};
 use crate::shape::{Dim, Dims};
 use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
@@ -634,7 +634,7 @@ impl Tensor {
             .storage()
             .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
         let op = if self.track_op() {
-            Some(Op::Max(self.clone(), max_dims.to_vec()))
+            Some(Op::Reduce(self.clone(), ReduceOp::Max, max_dims.to_vec()))
         } else {
             None
         };
@@ -656,7 +656,7 @@ impl Tensor {
             .storage()
             .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
         let op = if self.track_op() {
-            Some(Op::Min(self.clone(), min_dims.to_vec()))
+            Some(Op::Reduce(self.clone(), ReduceOp::Min, min_dims.to_vec()))
         } else {
             None
         };
@@ -678,7 +678,7 @@ impl Tensor {
             .storage()
             .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
         let op = if self.track_op() {
-            Some(Op::Sum(self.clone(), sum_dims.to_vec()))
+            Some(Op::Reduce(self.clone(), ReduceOp::Sum, sum_dims.to_vec()))
         } else {
             None
         };
@@ -748,6 +748,43 @@ impl Tensor {
         self.min(dims)
     }

+    pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
+        let shape = self.same_shape_binary_op(rhs, "cmp")?;
+        let storage = self
+            .storage()
+            .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
+        let op = if self.track_op() {
+            Some(Op::Cmp(self.clone(), op))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape.dims(), op, false))
+    }
+
+    pub fn eq(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Eq)
+    }
+
+    pub fn ne(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Ne)
+    }
+
+    pub fn lt(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Lt)
+    }
+
+    pub fn gt(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Gt)
+    }
+
+    pub fn ge(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Ge)
+    }
+
+    pub fn le(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Le)
+    }
+
     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.shape().r3()?;