mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add binary and ternary custom ops. (#217)
This commit is contained in:
@ -38,7 +38,7 @@ impl Tensor {
|
||||
nodes
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::WhereCond(t1, t2, t3) => {
|
||||
Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
|
||||
let (tg, nodes) = walk(t1, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(t2, nodes, already_seen);
|
||||
@ -52,6 +52,7 @@ impl Tensor {
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::CustomOp2(lhs, rhs, _)
|
||||
| Op::Binary(lhs, rhs, _)
|
||||
| Op::IndexSelect(lhs, rhs, _)
|
||||
| Op::Embedding(lhs, rhs)
|
||||
@ -321,9 +322,37 @@ impl Tensor {
|
||||
Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
|
||||
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
|
||||
Op::CustomOp1(arg, c) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let arg_grad = c.bwd(arg, node, &grad)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
}
|
||||
Op::CustomOp2(arg1, arg2, c) => {
|
||||
let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
|
||||
if let Some(arg_grad1) = arg_grad1 {
|
||||
let sum_grad = grads.or_insert(arg1)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad1)?
|
||||
}
|
||||
if let Some(arg_grad2) = arg_grad2 {
|
||||
let sum_grad = grads.or_insert(arg2)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad2)?
|
||||
}
|
||||
}
|
||||
Op::CustomOp3(arg1, arg2, arg3, c) => {
|
||||
let (arg_grad1, arg_grad2, arg_grad3) =
|
||||
c.bwd(arg1, arg2, arg3, node, &grad)?;
|
||||
if let Some(arg_grad1) = arg_grad1 {
|
||||
let sum_grad = grads.or_insert(arg1)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad1)?
|
||||
}
|
||||
if let Some(arg_grad2) = arg_grad2 {
|
||||
let sum_grad = grads.or_insert(arg2)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad2)?
|
||||
}
|
||||
if let Some(arg_grad3) = arg_grad3 {
|
||||
let sum_grad = grads.or_insert(arg3)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad3)?
|
||||
}
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Sqr) => {
|
||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||
|
Reference in New Issue
Block a user