Basic support for broadcasting backprop.

This commit is contained in:
laurent
2023-06-23 16:31:44 +01:00
parent 1936a1f0a3
commit d839d5d9fd

View File

@ -904,7 +904,7 @@ impl Tensor {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
-*rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
+*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
}
Op::Mul(lhs, rhs) => {
let lhs_grad = grad.mul(rhs)?;
@ -922,15 +922,17 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
-Op::BroadcastAdd(_lhs, _rhs) => {
-return Err(Error::BackwardNotSupported {
-op: "broadcast_add",
-})
+Op::BroadcastAdd(lhs, rhs) => {
+let lhs_sum_grad = grads.or_insert(lhs)?;
+*lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
+let rhs_sum_grad = grads.or_insert(rhs)?;
+*rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
}
-Op::BroadcastSub(_lhs, _rhs) => {
-return Err(Error::BackwardNotSupported {
-op: "broadcast_sub",
-})
+Op::BroadcastSub(lhs, rhs) => {
+let lhs_sum_grad = grads.or_insert(lhs)?;
+*lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
+let rhs_sum_grad = grads.or_insert(rhs)?;
+*rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
}
Op::BroadcastMul(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported {