From a6ca9baf3c959e1a41bfce7f223ef2166374fe8d Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 24 Jun 2023 15:17:57 +0100 Subject: [PATCH] Backprop for narrow. --- src/tensor.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/tensor.rs b/src/tensor.rs index b40ed886..b25a23c2 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -968,9 +968,15 @@ impl Tensor { let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } - Op::Cat(_args, _dim) => { - // TODO: Use narrow here. - return Err(Error::BackwardNotSupported { op: "cat" }); + Op::Cat(args, dim) => { + let mut start_idx = 0; + for arg in args { + let len = arg.dims()[*dim]; + let arg_grad = grad.narrow(*dim, start_idx, len)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)?; + start_idx += len; + } } Op::ToDType(arg) => { let sum_grad = grads.or_insert(arg)?;