mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add the comparison operations. (#207)
* Add the comparison operations. * Add the helper functions on the tensor side. * More cmp operations. * Cpu implementation for the comparison operations.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{Op, ReduceOp};
|
||||
use crate::op::{CmpOp, Op, ReduceOp};
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -634,7 +634,7 @@ impl Tensor {
|
||||
.storage()
|
||||
.reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Max(self.clone(), max_dims.to_vec()))
|
||||
Some(Op::Reduce(self.clone(), ReduceOp::Max, max_dims.to_vec()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -656,7 +656,7 @@ impl Tensor {
|
||||
.storage()
|
||||
.reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Min(self.clone(), min_dims.to_vec()))
|
||||
Some(Op::Reduce(self.clone(), ReduceOp::Min, min_dims.to_vec()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -678,7 +678,7 @@ impl Tensor {
|
||||
.storage()
|
||||
.reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
|
||||
Some(Op::Reduce(self.clone(), ReduceOp::Sum, sum_dims.to_vec()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -748,6 +748,43 @@ impl Tensor {
|
||||
self.min(dims)
|
||||
}
|
||||
|
||||
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
|
||||
let shape = self.same_shape_binary_op(rhs, "cmp")?;
|
||||
let storage = self
|
||||
.storage()
|
||||
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Cmp(self.clone(), op))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(from_storage(storage, shape.dims(), op, false))
|
||||
}
|
||||
|
||||
pub fn eq(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Eq)
|
||||
}
|
||||
|
||||
pub fn ne(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Ne)
|
||||
}
|
||||
|
||||
pub fn lt(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Lt)
|
||||
}
|
||||
|
||||
pub fn gt(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Gt)
|
||||
}
|
||||
|
||||
pub fn ge(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Ge)
|
||||
}
|
||||
|
||||
pub fn le(&self, rhs: &Self) -> Result<Self> {
|
||||
self.cmp(rhs, CmpOp::Le)
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
|
||||
|
Reference in New Issue
Block a user