Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 03:54:56 +00:00)
Op refactor (#208)

* Add the binary and unary op enums to factorize some code.
* Bugfix.
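In outline, the refactor folds the per-operation variants (Op::Add, Op::Sqr, ...) into two grouped variants, Op::Binary(lhs, rhs, BinaryOp) and Op::Unary(arg, UnaryOp). A minimal sketch of the shape this implies, reconstructed from the patterns matched in the hunks below (the Tensor field types, the Elu parameter type, and the exact variant list are assumptions, not the verbatim candle definitions):

    // Hypothetical sketch, not the verbatim candle source: variant names are
    // taken from the patterns matched in the diff; field types are assumed.
    use crate::Tensor;

    pub enum BinaryOp {
        Add,
        Mul,
        Sub,
        Div,
    }

    pub enum UnaryOp {
        Exp,
        Log,
        Sin,
        Cos,
        Abs,
        Neg,
        Sqr,
        Sqrt,
        Gelu,
        Relu,
    }

    pub enum Op {
        Binary(Tensor, Tensor, BinaryOp),
        Unary(Tensor, UnaryOp),
        // Elu keeps its own variant, apparently because it carries an
        // extra scalar parameter (assumed f64 here).
        Elu(Tensor, f64),
        // ... remaining variants (Matmul, Embedding, Softmax, ...) unchanged.
    }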
@@ -1,4 +1,4 @@
-use crate::op::{Op, ReduceOp};
+use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
 
@@ -39,10 +39,7 @@ impl Tensor {
                 kernel: rhs,
                 ..
             }
-            | Op::Add(lhs, rhs)
-            | Op::Mul(lhs, rhs)
-            | Op::Sub(lhs, rhs)
-            | Op::Div(lhs, rhs)
+            | Op::Binary(lhs, rhs, _)
             | Op::Embedding(lhs, rhs)
             | Op::Matmul(lhs, rhs) => {
                 let (tg, nodes) = walk(lhs, nodes, already_seen);
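This arm belongs to the dependency walk that decides which nodes need gradients; it only cares about the operands, so the op kind is wildcarded away. A standalone illustration, not candle code (the Node type and leaves function are invented for the example), of how one grouped variant replaces four arms:

    // Standalone illustration: a grouped Binary variant lets one match arm
    // handle add/mul/sub/div alike, since the walk only needs the two
    // operands and ignores which op produced the node.
    enum Node {
        Binary(Box<Node>, Box<Node>, &'static str),
        Leaf(i32),
    }

    fn leaves(n: &Node, out: &mut Vec<i32>) {
        match n {
            Node::Binary(lhs, rhs, _) => {
                leaves(lhs, out);
                leaves(rhs, out);
            }
            Node::Leaf(v) => out.push(*v),
        }
    }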
@@ -74,17 +71,8 @@ impl Tensor {
             | Op::Transpose(node, _, _)
             | Op::Narrow(node, _, _, _)
             | Op::Softmax(node, _)
-            | Op::Sqr(node)
-            | Op::Sqrt(node)
-            | Op::Gelu(node)
-            | Op::Relu(node)
-            | Op::Elu(node, _)
-            | Op::Exp(node)
-            | Op::Log(node)
-            | Op::Sin(node)
-            | Op::Cos(node)
-            | Op::Abs(node)
-            | Op::Neg(node) => {
+            | Op::Unary(node, _)
+            | Op::Elu(node, _) => {
                 let (tg, nodes) = walk(node, nodes, already_seen);
                 track_grad |= tg;
                 nodes
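Eleven single-input arms collapse into two. Op::Elu stays in the or-pattern rather than merging into Op::Unary, presumably because it carries an extra scalar parameter; the or-pattern works because every alternative binds the same node name. A standalone illustration of that idiom (types invented for the example):

    // Standalone illustration: or-pattern alternatives must bind the same
    // names, which is what lets Unary and the parameterized Elu share a
    // single arm in the dependency walk.
    enum Node {
        Unary(Box<Node>),
        Elu(Box<Node>, f64),
    }

    fn operand(n: &Node) -> &Node {
        match n {
            Node::Unary(node) | Node::Elu(node, _) => node,
        }
    }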
@@ -118,19 +106,19 @@ impl Tensor {
             // this is out of scope.
             if let Some(op) = node.op() {
                 match op {
-                    Op::Add(lhs, rhs) => {
+                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                     }
-                    Op::Sub(lhs, rhs) => {
+                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                     }
-                    Op::Mul(lhs, rhs) => {
+                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
                         let lhs_grad = grad.mul(rhs)?;
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@@ -138,7 +126,7 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::Div(lhs, rhs) => {
+                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
                         let lhs_grad = grad.div(rhs)?;
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
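For reference, these four arms accumulate the standard element-wise rules; the mul and div hunks only show the lhs half, with the rhs half cut off at the hunk boundary. Writing g for the incoming gradient on the node:

    \begin{aligned}
    \text{add:} &\quad \partial_{\text{lhs}} = g, & \partial_{\text{rhs}} &= g \\
    \text{sub:} &\quad \partial_{\text{lhs}} = g, & \partial_{\text{rhs}} &= -g \\
    \text{mul:} &\quad \partial_{\text{lhs}} = g \odot \text{rhs}, & \partial_{\text{rhs}} &= g \odot \text{lhs} \\
    \text{div:} &\quad \partial_{\text{lhs}} = g / \text{rhs}, & \partial_{\text{rhs}} &= -g \odot \text{lhs} / \text{rhs}^2
    \end{aligned}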
@@ -221,24 +209,26 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Log(arg) => {
+                    Op::Unary(arg, UnaryOp::Log) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&(&grad * *node)?)?
                     }
-                    Op::Sin(arg) => {
+                    Op::Unary(arg, UnaryOp::Sin) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
                     }
-                    Op::Cos(arg) => {
+                    Op::Unary(arg, UnaryOp::Cos) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                     }
-                    Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
-                    Op::Exp(arg) => {
+                    Op::Unary(_, UnaryOp::Abs) => {
+                        return Err(Error::BackwardNotSupported { op: "abs" })
+                    }
+                    Op::Unary(arg, UnaryOp::Exp) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&(&grad / arg)?)?
                     }
-                    Op::Neg(arg) => {
+                    Op::Unary(arg, UnaryOp::Neg) => {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.sub(&grad)?
                     }
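A quick sanity check against the textbook derivatives, again with g the incoming gradient:

    \begin{aligned}
    (\sin x)' &= \cos x, & (\cos x)' &= -\sin x, & (-x)' &= -1, \\
    (e^x)' &= e^x, & (\ln x)' &= 1/x. &&
    \end{aligned}

The Sin, Cos, and Neg arms line up with these (Cos and Neg subtract, encoding the minus sign). For Exp the rule calls for g times the node's output, and for Log for g divided by the input; the arms above multiply by the output in the Log case and divide by the input in the Exp case, which reads like the two rules swapped, so this spot is worth double-checking against upstream.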
@@ -276,15 +266,19 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
-                    Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
+                    Op::Unary(_, UnaryOp::Gelu) => {
+                        return Err(Error::BackwardNotSupported { op: "gelu" })
+                    }
+                    Op::Unary(_, UnaryOp::Relu) => {
+                        return Err(Error::BackwardNotSupported { op: "relu" })
+                    }
                     Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
-                    Op::Sqr(arg) => {
+                    Op::Unary(arg, UnaryOp::Sqr) => {
                         let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Sqrt(arg) => {
+                    Op::Unary(arg, UnaryOp::Sqrt) => {
                         let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
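The Sqr arm matches (x^2)' = 2x: arg.mul(&grad)?.affine(2., 0.)? computes 2 * arg * g. For Sqrt, (sqrt(x))' = 1/(2 sqrt(x)), i.e. g divided by twice the output; the arm divides g by the input arg instead, which is another spot that merits a numeric check. A standalone finite-difference checker for rules like these (plain f64, no candle types; everything here is invented for the example):

    // Standalone gradient check, not candle code: compares an analytic
    // derivative against a central finite difference at a single point.
    fn grad_check(f: impl Fn(f64) -> f64, df: impl Fn(f64) -> f64, x: f64) -> f64 {
        let eps = 1e-6;
        let numeric = (f(x + eps) - f(x - eps)) / (2.0 * eps);
        (numeric - df(x)).abs()
    }

    fn main() {
        // Both errors should be tiny (~1e-10) for the textbook rules.
        println!("sqr:  {:e}", grad_check(|x| x * x, |x| 2.0 * x, 3.0));
        println!("sqrt: {:e}", grad_check(|x| x.sqrt(), |x| 0.5 / x.sqrt(), 3.0));
    }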