mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Basic support for broadcasting backprop.
This commit is contained in:
@ -904,7 +904,7 @@ impl Tensor {
|
|||||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||||
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
|
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
|
||||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||||
*rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
|
*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
|
||||||
}
|
}
|
||||||
Op::Mul(lhs, rhs) => {
|
Op::Mul(lhs, rhs) => {
|
||||||
let lhs_grad = grad.mul(rhs)?;
|
let lhs_grad = grad.mul(rhs)?;
|
||||||
@ -922,15 +922,17 @@ impl Tensor {
|
|||||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||||
}
|
}
|
||||||
Op::BroadcastAdd(_lhs, _rhs) => {
|
Op::BroadcastAdd(lhs, rhs) => {
|
||||||
return Err(Error::BackwardNotSupported {
|
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||||
op: "broadcast_add",
|
*lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
|
||||||
})
|
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||||
|
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
|
||||||
}
|
}
|
||||||
Op::BroadcastSub(_lhs, _rhs) => {
|
Op::BroadcastSub(lhs, rhs) => {
|
||||||
return Err(Error::BackwardNotSupported {
|
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||||
op: "broadcast_sub",
|
*lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
|
||||||
})
|
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||||
|
*rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
|
||||||
}
|
}
|
||||||
Op::BroadcastMul(_lhs, _rhs) => {
|
Op::BroadcastMul(_lhs, _rhs) => {
|
||||||
return Err(Error::BackwardNotSupported {
|
return Err(Error::BackwardNotSupported {
|
||||||
|
Reference in New Issue
Block a user