Basic support for broadcasting backprop.

laurent
2023-06-23 16:31:44 +01:00
parent 1936a1f0a3
commit d839d5d9fd

@@ -904,7 +904,7 @@ impl Tensor {
                         let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                         let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
+                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                     }
                     Op::Mul(lhs, rhs) => {
                         let lhs_grad = grad.mul(rhs)?;
@@ -922,15 +922,17 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::BroadcastAdd(_lhs, _rhs) => {
-                        return Err(Error::BackwardNotSupported {
-                            op: "broadcast_add",
-                        })
+                    Op::BroadcastAdd(lhs, rhs) => {
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
                     }
-                    Op::BroadcastSub(_lhs, _rhs) => {
-                        return Err(Error::BackwardNotSupported {
-                            op: "broadcast_sub",
-                        })
+                    Op::BroadcastSub(lhs, rhs) => {
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
+                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
+                        *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
                     }
                     Op::BroadcastMul(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported {
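For reference, the rule the new BroadcastAdd/BroadcastSub arms implement: for c = a + broadcast(b), the gradient w.r.t. a is the incoming grad unchanged, while the full gradient w.r.t. b is the incoming grad summed over the broadcast dimensions (negated when b is the rhs of a subtraction). The sketch below is a standalone illustration of that reduction, not this codebase's API; the function name, Vec-based layout, and shapes are all hypothetical.

// Standalone sketch of the broadcast gradient rule (hypothetical helper,
// not part of this codebase): sums the incoming gradient over the
// broadcast (row) dimension to recover the rhs gradient.
fn grad_broadcast_rhs(grad: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; cols];
    for r in 0..rows {
        for c in 0..cols {
            // Each rhs element contributed to `rows` output elements,
            // so its gradient is the sum over that dimension.
            out[c] += grad[r * cols + c];
        }
    }
    out
}

fn main() {
    // Incoming gradient for a 2x3 output where rhs had shape (3,).
    let grad = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Column sums over the broadcast dimension: [5.0, 7.0, 9.0].
    assert_eq!(grad_broadcast_rhs(&grad, 2, 3), vec![5.0, 7.0, 9.0]);
}

Note that the arms as committed accumulate against the full-shaped grad via broadcast_add/broadcast_sub, so the stored gradient keeps the broadcast output shape rather than being reduced back to the operand's shape as in the sketch above; that missing reduction is presumably what makes this "basic" support.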