diff --git a/src/tensor.rs b/src/tensor.rs index b40ed886..b25a23c2 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -968,9 +968,15 @@ impl Tensor { let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } - Op::Cat(_args, _dim) => { - // TODO: Use narrow here. - return Err(Error::BackwardNotSupported { op: "cat" }); + Op::Cat(args, dim) => { + let mut start_idx = 0; + for arg in args { + let len = arg.dims()[*dim]; + let arg_grad = grad.narrow(*dim, start_idx, len)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = sum_grad.add(&arg_grad)?; + start_idx += len; + } } Op::ToDType(arg) => { let sum_grad = grads.or_insert(arg)?;