mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Backprop for narrow.
This commit is contained in:
@ -968,9 +968,15 @@ impl Tensor {
|
|||||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||||
}
|
}
|
||||||
Op::Cat(_args, _dim) => {
|
Op::Cat(args, dim) => {
|
||||||
// TODO: Use narrow here.
|
let mut start_idx = 0;
|
||||||
return Err(Error::BackwardNotSupported { op: "cat" });
|
for arg in args {
|
||||||
|
let len = arg.dims()[*dim];
|
||||||
|
let arg_grad = grad.narrow(*dim, start_idx, len)?;
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
*sum_grad = sum_grad.add(&arg_grad)?;
|
||||||
|
start_idx += len;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Op::ToDType(arg) => {
|
Op::ToDType(arg) => {
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
Reference in New Issue
Block a user