Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Add the comparison operations. (#207)
* Add the comparison operations.
* Add the helper functions on the tensor side.
* More cmp operations.
* Cpu implementation for the comparison operations.
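The diff below threads a new `cmp` operation from the `BackendStorage` trait through the CPU backend, the (still unimplemented) CUDA backend, the `Storage` dispatch layer, and the public `Tensor` helpers. As a quick orientation, here is a minimal usage sketch of what the helpers enable; the `candle_core` crate path and `Tensor::new`/`Device::Cpu` constructors are assumptions about the surrounding API, only the CPU backend implements `cmp` in this commit, and the resulting masks are `U8` tensors with no gradient (the backward pass rejects `Op::Cmp`):

```rust
// Minimal sketch (assumption: the crate is consumed as `candle_core`).
// Element-wise comparisons return a mask tensor with DType::U8.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu; // only the CPU backend implements cmp in this commit
    let a = Tensor::new(&[1u32, 2, 3, 4], &dev)?;
    let b = Tensor::new(&[4u32, 3, 2, 1], &dev)?;
    let gt = a.gt(&b)?; // U8 mask holding [0, 0, 1, 1]
    let eq = a.eq(&b)?; // U8 mask holding [0, 0, 0, 0]
    println!("{gt:?} {eq:?}");
    Ok(())
}
```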
@@ -1,3 +1,4 @@
+use crate::op::{CmpOp, ReduceOp};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 
 pub(crate) trait BackendStorage: Sized {
@@ -16,7 +17,9 @@ pub(crate) trait BackendStorage: Sized {
 
     fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
 
-    fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
+    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
+
+    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
 
     fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
 
@@ -1,4 +1,5 @@
-use crate::{op::Op, Error, Result, Tensor, TensorId};
+use crate::op::{Op, ReduceOp};
+use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
 
 impl Tensor {
@@ -66,9 +67,8 @@ impl Tensor {
             }
             Op::Reshape(node)
             | Op::Broadcast(node)
-            | Op::Sum(node, _)
-            | Op::Max(node, _)
-            | Op::Min(node, _)
+            | Op::Cmp(node, _)
+            | Op::Reduce(node, _, _)
             | Op::ToDType(node)
             | Op::ToDevice(node)
             | Op::Transpose(node, _, _)
@@ -201,14 +201,15 @@ impl Tensor {
             let sum_grad = grads.or_insert(arg)?;
             *sum_grad = sum_grad.broadcast_add(&arg_grad)?
         }
-        Op::Sum(arg, _sum_dims) => {
+        Op::Reduce(arg, ReduceOp::Sum, _) => {
             let sum_grad = grads.or_insert(arg)?;
             *sum_grad = sum_grad.broadcast_add(&grad)?
         }
-        Op::Max(_args, _sum_dims) => {
+        Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
+        Op::Reduce(_args, ReduceOp::Max, _) => {
             return Err(Error::BackwardNotSupported { op: "max" })
         }
-        Op::Min(_args, _sum_dims) => {
+        Op::Reduce(_args, ReduceOp::Min, _) => {
             return Err(Error::BackwardNotSupported { op: "min" })
         }
         Op::ToDType(arg) => {
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, ReduceOp, UnaryOp};
+use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp};
 use crate::{DType, Error, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 
@@ -62,6 +62,57 @@ trait Map2 {
     }
 }
 
+trait Map2U8 {
+    const OP: &'static str;
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
+
+    fn map(
+        &self,
+        v1: &CpuStorage,
+        l1: &Layout,
+        v2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<CpuStorage> {
+        match (v1, v2) {
+            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }
+            .bt()),
+        }
+    }
+}
+
+struct Cmp(CmpOp);
+impl Map2U8 for Cmp {
+    const OP: &'static str = "cmp";
+    #[inline(always)]
+    fn f<T: WithDType>(
+        &self,
+        lhs: &[T],
+        lhs_l: &Layout,
+        rhs: &[T],
+        rhs_l: &Layout,
+    ) -> Result<Vec<u8>> {
+        let dst = match self.0 {
+            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
+            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
+            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
+            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
+            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
+            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
+        };
+        Ok(dst)
+    }
+}
+
 struct WCond<'a>(&'a [u32], &'a Layout);
 
 impl<'a> Map2 for WCond<'a> {
@@ -269,13 +320,13 @@ fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
 }
 
 // This function maps over two strided index sequences.
-fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
+fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
     lhs_l: &Layout,
     rhs_l: &Layout,
     lhs: &[T],
     rhs: &[T],
     mut f: F,
-) -> Vec<T> {
+) -> Vec<U> {
     match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
         (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
             .iter()
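The only change `binary_map` itself needs is this signature generalization: the closure's output type `U` is decoupled from the element type `T`, which is what lets the `Cmp` kernel above feed `f32`/`f16`/integer inputs through it and collect `u8` results. A standalone sketch of that idea over the contiguous fast path (plain Rust, not candle code; the real function also walks strided index sequences via the two `Layout`s):

```rust
// Standalone illustration of the generalized mapping: T inputs, U outputs.
// Mirrors only the contiguous fast path of binary_map.
fn binary_map_contiguous<T: Copy, U, F: FnMut(T, T) -> U>(
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<U> {
    lhs.iter().zip(rhs.iter()).map(|(&l, &r)| f(l, r)).collect()
}

fn main() {
    let lhs = [1.0f32, 2.0, 3.0];
    let rhs = [3.0f32, 2.0, 1.0];
    // f32 inputs, u8 outputs: exactly what the comparison kernel needs.
    let mask: Vec<u8> = binary_map_contiguous(&lhs, &rhs, |x, y| u8::from(x < y));
    assert_eq!(mask, vec![1, 0, 0]);
}
```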
@@ -1064,6 +1115,10 @@ impl BackendStorage for CpuStorage {
         .map(self, layout)
     }
 
+    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+        Cmp(op).map(self, lhs_l, rhs, rhs_l)
+    }
+
     fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         // [self] stores data in a contiguous way starting at offset 0.
         match self {
@@ -1,4 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
+use crate::op::{CmpOp, ReduceOp};
 use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -515,7 +516,7 @@ impl<'a> Map1 for Sum<'a> {
     }
 }
 
-struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
+struct FastReduce<'a>(&'a [usize], ReduceOp);
 impl<'a> Map1 for FastReduce<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
@@ -558,9 +559,9 @@ impl<'a> Map1 for FastReduce<'a> {
             .w()?;
         let src = &src.slice(layout.start_offset()..);
         let name = match self.1 {
-            crate::op::ReduceOp::Sum => "fast_sum",
-            crate::op::ReduceOp::Min => "fast_min",
-            crate::op::ReduceOp::Max => "fast_max",
+            ReduceOp::Sum => "fast_sum",
+            ReduceOp::Min => "fast_min",
+            ReduceOp::Max => "fast_max",
         };
         let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
         // SAFETY: filled in by the follow up kernel.
@@ -961,17 +962,16 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
-    fn reduce_op(
-        &self,
-        op: crate::op::ReduceOp,
-        layout: &Layout,
-        sum_dims: &[usize],
-    ) -> Result<Self> {
+    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         let device = self.device().clone();
         let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
+    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement cmp").into())
+    }
+
     fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
         Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
     }
@@ -1,4 +1,5 @@
 #![allow(dead_code)]
+use crate::op::{CmpOp, ReduceOp};
 use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
 
 #[derive(Debug, Clone)]
@@ -40,7 +41,11 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
+    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
@@ -2,12 +2,31 @@ use crate::Tensor;
 use half::{bf16, f16};
 use num_traits::float::Float;
 
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum CmpOp {
+    Eq,
+    Ne,
+    Le,
+    Ge,
+    Lt,
+    Gt,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReduceOp {
+    Sum,
+    Min,
+    Max,
+}
+
 #[derive(Clone)]
 pub(crate) enum Op {
     Add(Tensor, Tensor),
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
+    Cmp(Tensor, CmpOp),
+    Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),
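These two public enums are the core of the refactor: `CmpOp` names the predicate carried by the new `Op::Cmp` variant, and `ReduceOp` lets the former `Op::Sum`/`Op::Max`/`Op::Min` variants collapse into a single `Op::Reduce(Tensor, ReduceOp, Vec<usize>)`, as the next hunk shows. A self-contained sketch (local copy of the enum, not candle code) of how one `CmpOp` value selects the per-element predicate, in the same way the CPU `Cmp` kernel does:

```rust
// Local mirror of CmpOp and the scalar predicate it selects.
#[allow(dead_code)] // not every variant is exercised in this tiny demo
#[derive(Clone, Copy)]
enum CmpOp {
    Eq,
    Ne,
    Le,
    Ge,
    Lt,
    Gt,
}

// PartialOrd implies PartialEq, so one bound covers all six predicates.
fn cmp_scalar<T: PartialOrd>(op: CmpOp, x: T, y: T) -> u8 {
    u8::from(match op {
        CmpOp::Eq => x == y,
        CmpOp::Ne => x != y,
        CmpOp::Le => x <= y,
        CmpOp::Ge => x >= y,
        CmpOp::Lt => x < y,
        CmpOp::Gt => x > y,
    })
}

fn main() {
    assert_eq!(cmp_scalar(CmpOp::Gt, 3u32, 1), 1);
    assert_eq!(cmp_scalar(CmpOp::Eq, 3u32, 1), 0);
}
```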
@@ -28,9 +47,6 @@ pub(crate) enum Op {
         mul: f64,
         add: f64,
     },
-    Sum(Tensor, Vec<usize>),
-    Max(Tensor, Vec<usize>),
-    Min(Tensor, Vec<usize>),
     ToDType(Tensor),
     Broadcast(Tensor),
     Exp(Tensor),
@@ -356,10 +372,3 @@ impl UnaryOp for Relu {
         v
     }
 }
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum ReduceOp {
-    Sum,
-    Min,
-    Max,
-}
@@ -1,5 +1,6 @@
 use crate::backend::BackendStorage;
-use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
+use crate::op::{self, CmpOp, ReduceOp};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
 
 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.
@@ -80,12 +81,38 @@ impl Storage {
         }
     }
 
-    pub(crate) fn reduce_op(
+    pub(crate) fn cmp(
         &self,
-        op: crate::op::ReduceOp,
-        layout: &Layout,
-        s: &[usize],
+        op: CmpOp,
+        rhs: &Self,
+        lhs_layout: &Layout,
+        rhs_layout: &Layout,
     ) -> Result<Self> {
+        self.same_device(rhs, "cmp")?;
+        self.same_dtype(rhs, "cmp")?;
+        match (self, rhs) {
+            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+                Ok(Self::Cuda(storage))
+            }
+            (lhs, rhs) => {
+                // Should not happen because of the same device check above but we're defensive
+                // anyway.
+                Err(Error::DeviceMismatchBinaryOp {
+                    lhs: lhs.device().location(),
+                    rhs: rhs.device().location(),
+                    op: "cmp",
+                }
+                .bt())
+            }
+        }
+    }
+
+    pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
                 let storage = storage.reduce_op(op, layout, s)?;
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{Op, ReduceOp};
+use crate::op::{CmpOp, Op, ReduceOp};
 use crate::shape::{Dim, Dims};
 use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
@@ -634,7 +634,7 @@ impl Tensor {
             .storage()
             .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
         let op = if self.track_op() {
-            Some(Op::Max(self.clone(), max_dims.to_vec()))
+            Some(Op::Reduce(self.clone(), ReduceOp::Max, max_dims.to_vec()))
         } else {
             None
         };
@@ -656,7 +656,7 @@ impl Tensor {
             .storage()
             .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
         let op = if self.track_op() {
-            Some(Op::Min(self.clone(), min_dims.to_vec()))
+            Some(Op::Reduce(self.clone(), ReduceOp::Min, min_dims.to_vec()))
         } else {
             None
         };
@@ -678,7 +678,7 @@ impl Tensor {
             .storage()
             .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
         let op = if self.track_op() {
-            Some(Op::Sum(self.clone(), sum_dims.to_vec()))
+            Some(Op::Reduce(self.clone(), ReduceOp::Sum, sum_dims.to_vec()))
         } else {
             None
         };
@@ -748,6 +748,43 @@ impl Tensor {
         self.min(dims)
     }
 
+    pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
+        let shape = self.same_shape_binary_op(rhs, "cmp")?;
+        let storage = self
+            .storage()
+            .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
+        let op = if self.track_op() {
+            Some(Op::Cmp(self.clone(), op))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape.dims(), op, false))
+    }
+
+    pub fn eq(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Eq)
+    }
+
+    pub fn ne(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Ne)
+    }
+
+    pub fn lt(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Lt)
+    }
+
+    pub fn gt(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Gt)
+    }
+
+    pub fn ge(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Ge)
+    }
+
+    pub fn le(&self, rhs: &Self) -> Result<Self> {
+        self.cmp(rhs, CmpOp::Le)
+    }
+
 /// Applies a 1D convolution over the input tensor.
 pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
     let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
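A short follow-up on how these helpers compose with existing ops: since the mask is a plain `U8` tensor, converting it to a numeric dtype and summing it counts the matching positions. This is a sketch under assumptions: the crate is imported as `candle_core`, and `to_dtype`/`sum` keep the signatures they have around this commit (`sum` taking the reduced dims as `&[usize]`):

```rust
// Sketch: count how many elements of `a` exceed the matching element of `b`.
// Assumptions: crate imported as `candle_core`; `to_dtype` and `sum(&[usize])`
// behave as in candle at the time of this commit; CPU backend only.
use candle_core::{DType, Device, Result, Tensor};

fn count_greater(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    let mask = a.gt(b)?; // U8 mask, same shape as the inputs
    mask.to_dtype(DType::U32)?.sum(&[0]) // reduce over dim 0 to get the count
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::new(&[1u32, 5, 3, 7], &dev)?;
    let b = Tensor::new(&[2u32, 2, 2, 2], &dev)?;
    println!("{:?}", count_greater(&a, &b)?); // expect a count of 3
    Ok(())
}
```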