More backprop support for broadcasting ops.

laurent
2023-06-23 16:35:10 +01:00
parent d839d5d9fd
commit 69f91b36f9

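For reference (this derivation is not part of the commit), the new match arms below follow the usual product and quotient rules, writing $g$ for the incoming gradient and $l$, $r$ for the two operands:

\[
\frac{\partial (l \cdot r)}{\partial l} = r, \qquad
\frac{\partial (l \cdot r)}{\partial r} = l, \qquad
\frac{\partial (l / r)}{\partial l} = \frac{1}{r}, \qquad
\frac{\partial (l / r)}{\partial r} = -\frac{l}{r^{2}},
\]

so the mul arm accumulates $g \cdot r$ and $g \cdot l$, while the div arm accumulates $g / r$ and subtracts the $g \cdot l / r^{2}$ term.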

@@ -934,15 +934,21 @@ impl Tensor {
                 let rhs_sum_grad = grads.or_insert(rhs)?;
                 *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
             }
-            Op::BroadcastMul(_lhs, _rhs) => {
-                return Err(Error::BackwardNotSupported {
-                    op: "broadcast_mul",
-                })
+            Op::BroadcastMul(lhs, rhs) => {
+                let lhs_grad = grad.broadcast_mul(rhs)?;
+                let lhs_sum_grad = grads.or_insert(lhs)?;
+                *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
+                let rhs_grad = grad.broadcast_mul(lhs)?;
+                let rhs_sum_grad = grads.or_insert(rhs)?;
+                *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
             }
-            Op::BroadcastDiv(_lhs, _rhs) => {
-                return Err(Error::BackwardNotSupported {
-                    op: "broadcast_div",
-                })
+            Op::BroadcastDiv(lhs, rhs) => {
+                let lhs_grad = grad.broadcast_div(rhs)?;
+                let lhs_sum_grad = grads.or_insert(lhs)?;
+                *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
+                let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
+                let rhs_sum_grad = grads.or_insert(rhs)?;
+                *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&rhs_grad)?;
             }
             Op::Embedding(_lhs, _rhs) => {
                 return Err(Error::BackwardNotSupported { op: "embedding" })
@@ -966,9 +972,8 @@ impl Tensor {
                 *sum_grad = sum_grad.add(&arg_grad)?
             }
             Op::Neg(arg) => {
-                let arg_grad = grad.neg()?;
                 let sum_grad = grads.or_insert(arg)?;
-                *sum_grad = sum_grad.add(&arg_grad)?
+                *sum_grad = sum_grad.sub(&grad)?
             }
             Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
             Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
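Below is a minimal usage sketch (not part of the commit) that exercises the newly supported broadcast_mul backward path. It assumes the present-day candle-core crate name and API surface (Var, Device::Cpu, broadcast_mul, sum_all, backward, GradStore::get), which may differ from the crate as it stood at the time of this commit.

use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // x has shape (2, 3); y has shape (3,) and is broadcast across x's rows.
    let x = Var::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    let y = Var::new(&[10f32, 20., 30.], &dev)?;
    // Elementwise product with broadcasting, reduced to a scalar loss.
    let z = x.broadcast_mul(&y)?;
    let loss = z.sum_all()?;
    let grads = loss.backward()?;
    // With the all-ones upstream gradient from sum_all, the gradient with
    // respect to x is simply y broadcast to x's shape.
    let grad_x = grads.get(&x).expect("no gradient for x");
    println!("grad_x: {grad_x}");
    Ok(())
}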