Add the comparison operations. (#207)

* Add the comparison operations.

* Add the helper functions on the tensor side.

* More cmp operations.

* Cpu implementation for the comparison operations.
This commit is contained in:
Laurent Mazare
2023-07-20 10:40:31 +02:00
committed by GitHub
parent dc416243a3
commit e9c052bf94
8 changed files with 178 additions and 41 deletions

View File

@ -2,12 +2,31 @@ use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
Eq,
Ne,
Le,
Ge,
Lt,
Gt,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
Sum,
Min,
Max,
}
#[derive(Clone)]
pub(crate) enum Op {
Add(Tensor, Tensor),
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
Cmp(Tensor, CmpOp),
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
WhereCond(Tensor, Tensor, Tensor),
@ -28,9 +47,6 @@ pub(crate) enum Op {
mul: f64,
add: f64,
},
Sum(Tensor, Vec<usize>),
Max(Tensor, Vec<usize>),
Min(Tensor, Vec<usize>),
ToDType(Tensor),
Broadcast(Tensor),
Exp(Tensor),
@ -356,10 +372,3 @@ impl UnaryOp for Relu {
v
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
Sum,
Min,
Max,
}