mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Backprop for narrow.
This commit is contained in:
@ -968,9 +968,15 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Cat(_args, _dim) => {
|
||||
// TODO: Use narrow here.
|
||||
return Err(Error::BackwardNotSupported { op: "cat" });
|
||||
Op::Cat(args, dim) => {
|
||||
let mut start_idx = 0;
|
||||
for arg in args {
|
||||
let len = arg.dims()[*dim];
|
||||
let arg_grad = grad.narrow(*dim, start_idx, len)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?;
|
||||
start_idx += len;
|
||||
}
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user