mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the comparison operations. (#207)
* Add the comparison operations. * Add the helper functions on the tensor side. * More cmp operations. * Cpu implementation for the comparison operations.
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
use crate::{op::Op, Error, Result, Tensor, TensorId};
|
||||
use crate::op::{Op, ReduceOp};
|
||||
use crate::{Error, Result, Tensor, TensorId};
|
||||
use std::collections::HashMap;
|
||||
|
||||
impl Tensor {
|
||||
@ -66,9 +67,8 @@ impl Tensor {
|
||||
}
|
||||
Op::Reshape(node)
|
||||
| Op::Broadcast(node)
|
||||
| Op::Sum(node, _)
|
||||
| Op::Max(node, _)
|
||||
| Op::Min(node, _)
|
||||
| Op::Cmp(node, _)
|
||||
| Op::Reduce(node, _, _)
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
@ -201,14 +201,15 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
|
||||
}
|
||||
Op::Sum(arg, _sum_dims) => {
|
||||
Op::Reduce(arg, ReduceOp::Sum, _) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.broadcast_add(&grad)?
|
||||
}
|
||||
Op::Max(_args, _sum_dims) => {
|
||||
Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
|
||||
Op::Reduce(_args, ReduceOp::Max, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "max" })
|
||||
}
|
||||
Op::Min(_args, _sum_dims) => {
|
||||
Op::Reduce(_args, ReduceOp::Min, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "min" })
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
|
Reference in New Issue
Block a user