Custom ops with a single argument (#214)

* Add the CustomOp1 trait.

* Add an example of custom op.

* Polish the custom op example.

* Add some backward pass test for custom ops.
This commit is contained in:
Laurent Mazare
2023-07-21 16:18:05 +02:00
committed by GitHub
parent b02229ce92
commit a6bcdfb269
8 changed files with 241 additions and 18 deletions

View File

@ -86,7 +86,8 @@ impl Tensor {
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
| Op::Unary(node, _)
| Op::Elu(node, _) => {
| Op::Elu(node, _)
| Op::CustomOp1(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
@ -319,6 +320,11 @@ impl Tensor {
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => {
let sum_grad = grads.or_insert(arg)?;
let arg_grad = c.bwd(arg, node, &grad)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(arg, UnaryOp::Sqr) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?;