mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values. * Add some time measurement.
This commit is contained in:
@ -4,6 +4,7 @@ use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{
|
||||
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
||||
};
|
||||
use crate::scalar::TensorOrScalar;
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -776,8 +777,15 @@ impl Tensor {
|
||||
/// comparison operation is specified by the `op` argument.
|
||||
///
|
||||
/// The returned tensor has the same shape as the original tensors and uses `u8` elements.
|
||||
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
|
||||
let shape = self.same_shape_binary_op(rhs, "cmp")?;
|
||||
pub fn cmp<T: TensorOrScalar>(&self, rhs: T, op: CmpOp) -> Result<Self> {
|
||||
let rhs = match rhs.to_tensor_scalar()? {
|
||||
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
|
||||
crate::scalar::TensorScalar::Scalar(rhs) => rhs
|
||||
.to_dtype(self.dtype())?
|
||||
.to_device(self.device())?
|
||||
.broadcast_as(self.shape())?,
|
||||
};
|
||||
let shape = self.same_shape_binary_op(&rhs, "cmp")?;
|
||||
let storage = self
|
||||
.storage()
|
||||
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
|
||||
@ -786,36 +794,36 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Element-wise equality.
|
||||
pub fn eq(&self, rhs: &Self) -> Result<Self> {
|
||||
pub fn eq<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Eq)
|
||||
}
|
||||
|
||||
/// Element-wise non-equality.
|
||||
pub fn ne(&self, rhs: &Self) -> Result<Self> {
|
||||
pub fn ne<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Ne)
|
||||
}
|
||||
|
||||
/// Element-wise comparison with lower-than, the returned tensor uses value 1 where `self <
|
||||
/// rhs` and 0 otherwise.
|
||||
pub fn lt(&self, rhs: &Self) -> Result<Self> {
|
||||
pub fn lt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Lt)
|
||||
}
|
||||
|
||||
/// Element-wise comparison with greater-than, the returned tensor uses value 1 where `self >
|
||||
/// rhs` and 0 otherwise.
|
||||
pub fn gt(&self, rhs: &Self) -> Result<Self> {
|
||||
pub fn gt<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Gt)
|
||||
}
|
||||
|
||||
/// Element-wise comparison with greater-equal, the returned tensor uses value 1 where `self >=
|
||||
/// rhs` and 0 otherwise.
|
||||
pub fn ge(&self, rhs: &Self) -> Result<Self> {
|
||||
pub fn ge<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Ge)
|
||||
}
|
||||
|
||||
/// Element-wise comparison with lower-equal, the returned tensor uses value 1 where `self <=
|
||||
/// rhs` and 0 otherwise.
|
||||
pub fn le(&self, rhs: &Self) -> Result<Self> {
|
||||
pub fn le<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Le)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user