More backprop support for broadcasting ops.

This commit is contained in:
laurent
2023-06-23 16:35:10 +01:00
parent d839d5d9fd
commit 69f91b36f9

View File

@ -934,15 +934,21 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?; let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?; *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
} }
Op::BroadcastMul(_lhs, _rhs) => { Op::BroadcastMul(lhs, rhs) => {
return Err(Error::BackwardNotSupported { let lhs_grad = grad.broadcast_mul(rhs)?;
op: "broadcast_mul", let lhs_sum_grad = grads.or_insert(lhs)?;
}) *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
let rhs_grad = grad.broadcast_mul(lhs)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
} }
Op::BroadcastDiv(_lhs, _rhs) => { Op::BroadcastDiv(lhs, rhs) => {
return Err(Error::BackwardNotSupported { let lhs_grad = grad.broadcast_div(rhs)?;
op: "broadcast_div", let lhs_sum_grad = grads.or_insert(lhs)?;
}) *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
} }
Op::Embedding(_lhs, _rhs) => { Op::Embedding(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported { op: "embedding" }) return Err(Error::BackwardNotSupported { op: "embedding" })
@ -966,9 +972,8 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
Op::Neg(arg) => { Op::Neg(arg) => {
let arg_grad = grad.neg()?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.sub(&grad)?
} }
Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }), Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "reshape" }), Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "reshape" }),