mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Basic support for broadcasting backprop.
This commit is contained in:
@ -904,7 +904,7 @@ impl Tensor {
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
|
||||
}
|
||||
Op::Mul(lhs, rhs) => {
|
||||
let lhs_grad = grad.mul(rhs)?;
|
||||
@ -922,15 +922,17 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::BroadcastAdd(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported {
|
||||
op: "broadcast_add",
|
||||
})
|
||||
Op::BroadcastAdd(lhs, rhs) => {
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
|
||||
}
|
||||
Op::BroadcastSub(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported {
|
||||
op: "broadcast_sub",
|
||||
})
|
||||
Op::BroadcastSub(lhs, rhs) => {
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
|
||||
}
|
||||
Op::BroadcastMul(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported {
|
||||
|
Reference in New Issue
Block a user