mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the comparison operations. (#207)
* Add the comparison operations. * Add the helper functions on the tensor side. * More cmp operations. * Cpu implementation for the comparison operations.
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
use crate::op::{self, CmpOp, ReduceOp};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
// out of memory. Instead try_clone should be used.
|
||||
@ -80,12 +81,38 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn reduce_op(
|
||||
pub(crate) fn cmp(
|
||||
&self,
|
||||
op: crate::op::ReduceOp,
|
||||
layout: &Layout,
|
||||
s: &[usize],
|
||||
op: CmpOp,
|
||||
rhs: &Self,
|
||||
lhs_layout: &Layout,
|
||||
rhs_layout: &Layout,
|
||||
) -> Result<Self> {
|
||||
self.same_device(rhs, "cmp")?;
|
||||
self.same_dtype(rhs, "cmp")?;
|
||||
match (self, rhs) {
|
||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
// Should not happen because of the same device check above but we're defensive
|
||||
// anyway.
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "cmp",
|
||||
}
|
||||
.bt())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.reduce_op(op, layout, s)?;
|
||||
|
Reference in New Issue
Block a user