mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add the comparison operations. (#207)
* Add the comparison operations. * Add the helper functions on the tensor side. * More cmp operations. * Cpu implementation for the comparison operations.
This commit is contained in:
@ -2,12 +2,31 @@ use crate::Tensor;
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CmpOp {
|
||||
Eq,
|
||||
Ne,
|
||||
Le,
|
||||
Ge,
|
||||
Lt,
|
||||
Gt,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ReduceOp {
|
||||
Sum,
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum Op {
|
||||
Add(Tensor, Tensor),
|
||||
Mul(Tensor, Tensor),
|
||||
Sub(Tensor, Tensor),
|
||||
Div(Tensor, Tensor),
|
||||
Cmp(Tensor, CmpOp),
|
||||
Reduce(Tensor, ReduceOp, Vec<usize>),
|
||||
Matmul(Tensor, Tensor),
|
||||
Embedding(Tensor, Tensor),
|
||||
WhereCond(Tensor, Tensor, Tensor),
|
||||
@ -28,9 +47,6 @@ pub(crate) enum Op {
|
||||
mul: f64,
|
||||
add: f64,
|
||||
},
|
||||
Sum(Tensor, Vec<usize>),
|
||||
Max(Tensor, Vec<usize>),
|
||||
Min(Tensor, Vec<usize>),
|
||||
ToDType(Tensor),
|
||||
Broadcast(Tensor),
|
||||
Exp(Tensor),
|
||||
@ -356,10 +372,3 @@ impl UnaryOp for Relu {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ReduceOp {
|
||||
Sum,
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
Reference in New Issue
Block a user