mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
More backprop support for broadcasting ops.
This commit is contained in:
@ -934,15 +934,21 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
|
||||
}
|
||||
Op::BroadcastMul(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported {
|
||||
op: "broadcast_mul",
|
||||
})
|
||||
Op::BroadcastMul(lhs, rhs) => {
|
||||
let lhs_grad = grad.broadcast_mul(rhs)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
|
||||
let rhs_grad = grad.broadcast_mul(lhs)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
|
||||
}
|
||||
Op::BroadcastDiv(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported {
|
||||
op: "broadcast_div",
|
||||
})
|
||||
Op::BroadcastDiv(lhs, rhs) => {
|
||||
let lhs_grad = grad.broadcast_div(rhs)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
|
||||
let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
|
||||
}
|
||||
Op::Embedding(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||
@ -966,9 +972,8 @@ impl Tensor {
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Neg(arg) => {
|
||||
let arg_grad = grad.neg()?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
*sum_grad = sum_grad.sub(&grad)?
|
||||
}
|
||||
Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
|
||||
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "reshape" }),
|
||||
|
Reference in New Issue
Block a user