Backprop for narrow.

This commit is contained in:
laurent
2023-06-24 15:17:57 +01:00
parent fbbf3951dd
commit a6ca9baf3c

View File

@ -968,9 +968,15 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?; let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
} }
Op::Cat(_args, _dim) => { Op::Cat(args, dim) => {
// TODO: Use narrow here. let mut start_idx = 0;
return Err(Error::BackwardNotSupported { op: "cat" }); for arg in args {
let len = arg.dims()[*dim];
let arg_grad = grad.narrow(*dim, start_idx, len)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?;
start_idx += len;
}
} }
Op::ToDType(arg) => { Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;