Add the comparison operations. (#207)

* Add the comparison operations.

* Add the helper functions on the tensor side.

* More cmp operations.

* Cpu implementation for the comparison operations.
This commit is contained in:
Laurent Mazare
2023-07-20 10:40:31 +02:00
committed by GitHub
parent dc416243a3
commit e9c052bf94
8 changed files with 178 additions and 41 deletions

View File

@ -1,5 +1,6 @@
use crate::backend::BackendStorage;
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
use crate::op::{self, CmpOp, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -80,12 +81,38 @@ impl Storage {
}
}
pub(crate) fn reduce_op(
pub(crate) fn cmp(
&self,
op: crate::op::ReduceOp,
layout: &Layout,
s: &[usize],
op: CmpOp,
rhs: &Self,
lhs_layout: &Layout,
rhs_layout: &Layout,
) -> Result<Self> {
self.same_device(rhs, "cmp")?;
self.same_dtype(rhs, "cmp")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "cmp",
}
.bt())
}
}
}
pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.reduce_op(op, layout, s)?;