From 69f91b36f91a95929d29ca99a436667e7fd51ee4 Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 23 Jun 2023 16:35:10 +0100
Subject: [PATCH] More backprop support for broadcasting ops.

---
 src/tensor.rs | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/src/tensor.rs b/src/tensor.rs
index 7a70efba..50b8cadc 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -934,15 +934,21 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
                     }
-                    Op::BroadcastMul(_lhs, _rhs) => {
-                        return Err(Error::BackwardNotSupported {
-                            op: "broadcast_mul",
-                        })
+                    Op::BroadcastMul(lhs, rhs) => {
+                        let lhs_grad = grad.broadcast_mul(rhs)?;
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
+                        let rhs_grad = grad.broadcast_mul(lhs)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
                     }
-                    Op::BroadcastDiv(_lhs, _rhs) => {
-                        return Err(Error::BackwardNotSupported {
-                            op: "broadcast_div",
-                        })
+                    Op::BroadcastDiv(lhs, rhs) => {
+                        let lhs_grad = grad.broadcast_div(rhs)?;
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
+                        let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&rhs_grad)?;
                     }
                     Op::Embedding(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported { op: "embedding" })
@@ -966,9 +972,8 @@ impl Tensor {
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Neg(arg) => {
-                        let arg_grad = grad.neg()?;
                         let sum_grad = grads.or_insert(arg)?;
-                        *sum_grad = sum_grad.add(&arg_grad)?
+                        *sum_grad = sum_grad.sub(&grad)?
                     }
                     Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
                     Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "reshape" }),
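
Note (illustrative only, not part of the patch or of the crate's API): the new match arms apply the usual elementwise chain rule. With upstream gradient g = dL/dc, a product c = a * b contributes g * b to a and g * a to b; a quotient c = a / b contributes g / b to a and -g * a / b^2 to b (hence the rhs gradient is subtracted); negation c = -a contributes -g. The standalone snippet below checks these scalar formulas against central finite differences using only the Rust standard library.

// Illustrative scalar check, independent of the patch above: verifies the
// elementwise gradient rules that the BroadcastMul / BroadcastDiv / Neg arms
// apply, by comparing the analytic formulas against central finite differences.
fn main() {
    let (a, b, eps) = (1.7_f64, 0.9_f64, 1e-6);

    // Product c = a * b: dc/da = b and dc/db = a.
    let dmul_db = (a * (b + eps) - a * (b - eps)) / (2.0 * eps);
    assert!((dmul_db - a).abs() < 1e-6);

    // Quotient c = a / b: dc/da = 1 / b and dc/db = -a / b^2,
    // which is why the rhs gradient is accumulated with a subtraction.
    let ddiv_da = ((a + eps) / b - (a - eps) / b) / (2.0 * eps);
    let ddiv_db = (a / (b + eps) - a / (b - eps)) / (2.0 * eps);
    assert!((ddiv_da - 1.0 / b).abs() < 1e-6);
    assert!((ddiv_db + a / (b * b)).abs() < 1e-6);

    // Negation c = -a: dc/da = -1, so the upstream gradient is subtracted as well.
    let dneg_da = (-(a + eps) - -(a - eps)) / (2.0 * eps);
    assert!((dneg_da + 1.0).abs() < 1e-6);

    println!("finite-difference gradient checks passed");
}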