From d839d5d9fda698f8eaef066c3219cd9ec5bc1c88 Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 23 Jun 2023 16:31:44 +0100
Subject: [PATCH] Basic support for broadcasting backprop.

---
 src/tensor.rs | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/tensor.rs b/src/tensor.rs
index 2dbff9be..7a70efba 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -904,7 +904,7 @@ impl Tensor {
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                         let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
+                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                     }
                     Op::Mul(lhs, rhs) => {
                         let lhs_grad = grad.mul(rhs)?;
@@ -922,15 +922,17 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::BroadcastAdd(_lhs, _rhs) => {
-                        return Err(Error::BackwardNotSupported {
-                            op: "broadcast_add",
-                        })
+                    Op::BroadcastAdd(lhs, rhs) => {
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
                     }
-                    Op::BroadcastSub(_lhs, _rhs) => {
-                        return Err(Error::BackwardNotSupported {
-                            op: "broadcast_sub",
-                        })
+                    Op::BroadcastSub(lhs, rhs) => {
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
                     }
                     Op::BroadcastMul(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported {
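
For reference, the rule these hunks work toward: when an operand is broadcast up to the other operand's shape, the gradient flowing back to the broadcast operand is the incoming gradient summed over the broadcast dimensions (negated for the right-hand side of a subtraction), while the non-broadcast operand receives the gradient unchanged. Below is a minimal standalone sketch of that reduction rule in plain Rust; all names are illustrative and none of this is candle API or part of the patch.

    // Standalone sketch of the gradient rule behind broadcasting backprop.
    // `grad` has the broadcast (larger) shape [rows, cols]; the operand that
    // was broadcast had shape [cols], so its gradient is `grad` summed over
    // the leading (broadcast) axis. Illustrative code only, not candle API.

    /// Sum a [rows, cols] gradient over its leading axis, producing the
    /// [cols]-shaped gradient of the operand that was broadcast.
    fn reduce_broadcast_grad(grad: &[f64], rows: usize, cols: usize) -> Vec<f64> {
        let mut out = vec![0.0; cols];
        for r in 0..rows {
            for c in 0..cols {
                out[c] += grad[r * cols + c];
            }
        }
        out
    }

    fn main() {
        // Incoming gradient for z = broadcast_add(lhs, rhs), z of shape [2, 3].
        let grad = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        // d z / d lhs is the identity: lhs receives `grad` unchanged.
        // d z / d rhs sums over the broadcast axis: [1+4, 2+5, 3+6] = [5, 7, 9].
        let rhs_grad = reduce_broadcast_grad(&grad, 2, 3);
        assert_eq!(rhs_grad, vec![5.0, 7.0, 9.0]);

        // For broadcast_sub the rhs gradient is the same reduction, negated.
        let rhs_grad_sub: Vec<f64> = rhs_grad.iter().map(|g| -g).collect();
        assert_eq!(rhs_grad_sub, vec![-5.0, -7.0, -9.0]);
        println!("rhs grad (add): {rhs_grad:?}, rhs grad (sub): {rhs_grad_sub:?}");
    }

The sketch shows the reduction in isolation; it does not mirror how the patch accumulates gradients into the grad store.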