Add the comparison operations. (#207)

* Add the comparison operations.

* Add the helper functions on the tensor side.

* More cmp operations.

* Cpu implementation for the comparison operations.
This commit is contained in:
Laurent Mazare
2023-07-20 10:40:31 +02:00
committed by GitHub
parent dc416243a3
commit e9c052bf94
8 changed files with 178 additions and 41 deletions

View File

@ -1,4 +1,5 @@
use crate::{op::Op, Error, Result, Tensor, TensorId};
use crate::op::{Op, ReduceOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
impl Tensor {
@ -66,9 +67,8 @@ impl Tensor {
}
Op::Reshape(node)
| Op::Broadcast(node)
| Op::Sum(node, _)
| Op::Max(node, _)
| Op::Min(node, _)
| Op::Cmp(node, _)
| Op::Reduce(node, _, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@ -201,14 +201,15 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
}
Op::Sum(arg, _sum_dims) => {
Op::Reduce(arg, ReduceOp::Sum, _) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&grad)?
}
Op::Max(_args, _sum_dims) => {
Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
Op::Reduce(_args, ReduceOp::Max, _) => {
return Err(Error::BackwardNotSupported { op: "max" })
}
Op::Min(_args, _sum_dims) => {
Op::Reduce(_args, ReduceOp::Min, _) => {
return Err(Error::BackwardNotSupported { op: "min" })
}
Op::ToDType(arg) => {