Add the comparison operations. (#207)

* Add the comparison operations.

* Add the helper functions on the tensor side.

* More cmp operations.

* Cpu implementation for the comparison operations.
This commit is contained in:
Laurent Mazare
2023-07-20 10:40:31 +02:00
committed by GitHub
parent dc416243a3
commit e9c052bf94
8 changed files with 178 additions and 41 deletions

View File

@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{Op, ReduceOp};
use crate::op::{CmpOp, Op, ReduceOp};
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -634,7 +634,7 @@ impl Tensor {
.storage()
.reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
let op = if self.track_op() {
Some(Op::Max(self.clone(), max_dims.to_vec()))
Some(Op::Reduce(self.clone(), ReduceOp::Max, max_dims.to_vec()))
} else {
None
};
@ -656,7 +656,7 @@ impl Tensor {
.storage()
.reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
let op = if self.track_op() {
Some(Op::Min(self.clone(), min_dims.to_vec()))
Some(Op::Reduce(self.clone(), ReduceOp::Min, min_dims.to_vec()))
} else {
None
};
@ -678,7 +678,7 @@ impl Tensor {
.storage()
.reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
let op = if self.track_op() {
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
Some(Op::Reduce(self.clone(), ReduceOp::Sum, sum_dims.to_vec()))
} else {
None
};
@ -748,6 +748,43 @@ impl Tensor {
self.min(dims)
}
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, "cmp")?;
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
let op = if self.track_op() {
Some(Op::Cmp(self.clone(), op))
} else {
None
};
Ok(from_storage(storage, shape.dims(), op, false))
}
pub fn eq(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Eq)
}
pub fn ne(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Ne)
}
pub fn lt(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Lt)
}
pub fn gt(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Gt)
}
pub fn ge(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Ge)
}
pub fn le(&self, rhs: &Self) -> Result<Self> {
self.cmp(rhs, CmpOp::Le)
}
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;