mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add the comparison operations. (#207)
* Add the comparison operations. * Add the helper functions on the tensor side. * More cmp operations. * Cpu implementation for the comparison operations.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOp, ReduceOp, UnaryOp};
|
||||
use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp};
|
||||
use crate::{DType, Error, Layout, Result, Shape, WithDType};
|
||||
use half::{bf16, f16};
|
||||
|
||||
@ -62,6 +62,57 @@ trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map2U8 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||
|
||||
fn map(
|
||||
&self,
|
||||
v1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
v2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
) -> Result<CpuStorage> {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: v1.dtype(),
|
||||
rhs: v2.dtype(),
|
||||
op: Self::OP,
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Cmp(CmpOp);
|
||||
impl Map2U8 for Cmp {
|
||||
const OP: &'static str = "cmp";
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(
|
||||
&self,
|
||||
lhs: &[T],
|
||||
lhs_l: &Layout,
|
||||
rhs: &[T],
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Vec<u8>> {
|
||||
let dst = match self.0 {
|
||||
CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
|
||||
CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
|
||||
CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
|
||||
CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
|
||||
CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
|
||||
CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
|
||||
};
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct WCond<'a>(&'a [u32], &'a Layout);
|
||||
|
||||
impl<'a> Map2 for WCond<'a> {
|
||||
@ -269,13 +320,13 @@ fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||
}
|
||||
|
||||
// This function maps over two strided index sequences.
|
||||
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
mut f: F,
|
||||
) -> Vec<T> {
|
||||
) -> Vec<U> {
|
||||
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
|
||||
.iter()
|
||||
@ -1064,6 +1115,10 @@ impl BackendStorage for CpuStorage {
|
||||
.map(self, layout)
|
||||
}
|
||||
|
||||
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||
Cmp(op).map(self, lhs_l, rhs, rhs_l)
|
||||
}
|
||||
|
||||
fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way starting at offset 0.
|
||||
match self {
|
||||
|
Reference in New Issue
Block a user