Add binary and ternary custom ops. (#217)

Laurent Mazare
2023-07-21 18:29:50 +02:00
committed by GitHub
parent 4a100875bf
commit 5cc843550d
5 changed files with 209 additions and 10 deletions
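
For context: the new CustomOp2 and CustomOp3 variants mirror the existing CustomOp1, and their bwd methods return one Option<Tensor> per input so that a non-differentiable argument can simply yield None (the bwd shapes below are confirmed by the calls in this diff). A minimal sketch of a two-input op; the MulOp name, the cpu_fwd signature, and the surrounding trait details are assumptions based on candle's custom-op API of this era, not part of this commit:

use candle_core::{CpuStorage, CustomOp2, Layout, Result, Shape, Tensor};

// Hypothetical elementwise-product op, for illustration only.
struct MulOp;

impl CustomOp2 for MulOp {
    fn name(&self) -> &'static str {
        "mul-op"
    }

    // Forward CPU kernel elided; the point here is the gradient contract.
    fn cpu_fwd(
        &self,
        _s1: &CpuStorage,
        _l1: &Layout,
        _s2: &CpuStorage,
        _l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        todo!("elementwise multiply kernel")
    }

    // One Option per input: returning None means "no gradient flows back
    // through this argument", which the new backprop arms simply skip.
    fn bwd(
        &self,
        arg1: &Tensor,
        arg2: &Tensor,
        _res: &Tensor,
        grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        // d(a*b)/da = grad * b, d(a*b)/db = grad * a
        Ok((Some(grad_res.mul(arg2)?), Some(grad_res.mul(arg1)?)))
    }
}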

@@ -38,7 +38,7 @@ impl Tensor {
             nodes
         } else if let Some(op) = node.op() {
             match op {
-                Op::WhereCond(t1, t2, t3) => {
+                Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
                     let (tg, nodes) = walk(t1, nodes, already_seen);
                     track_grad |= tg;
                     let (tg, nodes) = walk(t2, nodes, already_seen);
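
This first hunk makes the topological walk treat a ternary custom op like where_cond: all three inputs are visited and the per-input flags are OR-ed into track_grad, so the op participates in backprop as soon as any one input requires a gradient. A self-contained sketch of that traversal idea (toy types, not candle code):

// Toy graph, not candle code: a node needs gradient tracking if it is a
// variable or if any of its inputs does.
enum Node {
    Var,                                  // leaf that requires a gradient
    Const,                                // leaf that does not
    Op3(Box<Node>, Box<Node>, Box<Node>), // ternary op, like CustomOp3
}

fn track_grad(node: &Node) -> bool {
    match node {
        Node::Var => true,
        Node::Const => false,
        // Same shape as the hunk above: walk all three inputs and OR the
        // results together.
        Node::Op3(t1, t2, t3) => {
            let mut tg = false;
            tg |= track_grad(t1);
            tg |= track_grad(t2);
            tg |= track_grad(t3);
            tg
        }
    }
}

fn main() {
    let n = Node::Op3(
        Box::new(Node::Const),
        Box::new(Node::Var),
        Box::new(Node::Const),
    );
    assert!(track_grad(&n)); // one differentiable input is enough
}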
@@ -52,6 +52,7 @@ impl Tensor {
                     kernel: rhs,
                     ..
                 }
+                | Op::CustomOp2(lhs, rhs, _)
                 | Op::Binary(lhs, rhs, _)
                 | Op::IndexSelect(lhs, rhs, _)
                 | Op::Embedding(lhs, rhs)
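
The second hunk only needs to add CustomOp2 to the existing or-pattern covering every two-input op: all alternatives in a Rust or-pattern must bind the same names, so lhs and rhs are then walked once for whichever variant matched. A tiny standalone illustration of that binding rule (toy enum, not candle's Op):

enum Op {
    Add(u32, u32),
    Mul(u32, u32),
    Custom2(u32, u32, &'static str),
}

// Every alternative binds lhs and rhs with the same types, so one arm
// can extract the operands of any two-input variant.
fn operands(op: &Op) -> (u32, u32) {
    match op {
        Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) | Op::Custom2(lhs, rhs, _) => (*lhs, *rhs),
    }
}

fn main() {
    assert_eq!(operands(&Op::Custom2(3, 4, "c")), (3, 4));
}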
@@ -321,9 +322,37 @@ impl Tensor {
                 Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
                 Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                 Op::CustomOp1(arg, c) => {
-                    let sum_grad = grads.or_insert(arg)?;
-                    let arg_grad = c.bwd(arg, node, &grad)?;
-                    *sum_grad = sum_grad.add(&arg_grad)?
+                    if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                 }
+                Op::CustomOp2(arg1, arg2, c) => {
+                    let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
+                    if let Some(arg_grad1) = arg_grad1 {
+                        let sum_grad = grads.or_insert(arg1)?;
+                        *sum_grad = sum_grad.add(&arg_grad1)?
+                    }
+                    if let Some(arg_grad2) = arg_grad2 {
+                        let sum_grad = grads.or_insert(arg2)?;
+                        *sum_grad = sum_grad.add(&arg_grad2)?
+                    }
+                }
+                Op::CustomOp3(arg1, arg2, arg3, c) => {
+                    let (arg_grad1, arg_grad2, arg_grad3) =
+                        c.bwd(arg1, arg2, arg3, node, &grad)?;
+                    if let Some(arg_grad1) = arg_grad1 {
+                        let sum_grad = grads.or_insert(arg1)?;
+                        *sum_grad = sum_grad.add(&arg_grad1)?
+                    }
+                    if let Some(arg_grad2) = arg_grad2 {
+                        let sum_grad = grads.or_insert(arg2)?;
+                        *sum_grad = sum_grad.add(&arg_grad2)?
+                    }
+                    if let Some(arg_grad3) = arg_grad3 {
+                        let sum_grad = grads.or_insert(arg3)?;
+                        *sum_grad = sum_grad.add(&arg_grad3)?
+                    }
+                }
                 Op::Unary(arg, UnaryOp::Sqr) => {
                     let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
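
Each Some gradient above is accumulated with or_insert followed by add, so an input feeding several ops receives the sum of all contributions, while a None leaves that input's gradient store untouched. A usage sketch under the same assumptions as the MulOp example above; Var, apply_op2, sum_all, and GradStore::get are taken from candle's API of this era and are not shown in this diff:

use candle_core::{Device, Var};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let a = Var::new(&[1f32, 2., 3.], &dev)?;
    let b = Var::new(&[4f32, 5., 6.], &dev)?;
    // apply_op2 records Op::CustomOp2 in the graph; backward() then runs
    // the new arm above and sums the per-input gradients.
    // (Requires MulOp's cpu_fwd to be implemented rather than todo!().)
    let y = a.apply_op2(&b, MulOp)?;
    let grads = y.sum_all()?.backward()?;
    // d(sum(a*b))/da = b
    println!("{:?}", grads.get(&a));
    Ok(())
}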