From 1cfc5d6d0c1a2d7cfa2cc62948fe324df3fd0ac3 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 3 Nov 2023 14:23:53 +0100
Subject: [PATCH] Backprop support for conv1d (cpu only for now). (#1255)

---
 candle-core/src/backprop.rs | 39 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 30b9fb7c..da6dbb66 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -213,7 +213,44 @@ impl Tensor {
                     let f_grad = pred.where_cond(&zeros, &grad)?;
                     *f_sum_grad = f_sum_grad.add(&f_grad)?;
                 }
-                Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                Op::Conv1D {
+                    arg,
+                    kernel,
+                    padding,
+                    stride,
+                    dilation,
+                } => {
+                    // The output height for conv_transpose1d is:
+                    // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
+                    let grad_l_in = grad.dim(2)?;
+                    let k_size = kernel.dim(2)?;
+                    let out_size =
+                        (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
+                    let out_padding = arg.dim(2)? - out_size;
+                    let grad_arg = grad.conv_transpose1d(
+                        kernel,
+                        *padding,
+                        out_padding,
+                        *stride,
+                        *dilation,
+                    )?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&grad_arg)?;
+
+                    let grad_kernel = arg
+                        .transpose(0, 1)?
+                        .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
+                        .transpose(0, 1)?;
+                    let sum_grad = grads.or_insert(kernel)?;
+                    let (_, _, k0) = kernel.dims3()?;
+                    let (_, _, g_k0) = grad_kernel.dims3()?;
+                    let grad_kernel = if g_k0 != k0 {
+                        grad_kernel.narrow(2, 0, k0)?
+                    } else {
+                        grad_kernel
+                    };
+                    *sum_grad = sum_grad.add(&grad_kernel)?;
+                }
                 Op::Conv2D {
                     arg,
                     kernel,
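
The backward rule above computes the input gradient with conv_transpose1d (the adjoint of conv1d), using out_padding to recover positions that the strided forward pass floors away, and the kernel gradient by correlating the input with the output gradient (channels moved to the batch dimension via transpose(0, 1), with a trailing narrow trimming any extra tap). Below is a minimal sketch of exercising this new path through candle's public API; it assumes a build that includes this patch, and the tensor shapes and conv parameters are illustrative, not taken from the patch.

// Minimal sketch: differentiate through conv1d on CPU.
// Shapes and parameters below are illustrative assumptions.
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input (batch = 1, c_in = 2, l_in = 8) and kernel (c_out = 3, c_in = 2, k_size = 3),
    // wrapped in Var so that backward() tracks their gradients.
    let x = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1, 2, 8), &dev)?)?;
    let w = Var::from_tensor(&Tensor::randn(0f32, 1f32, (3, 2, 3), &dev)?)?;
    // conv1d(kernel, padding, stride, dilation, groups).
    let y = x.conv1d(&w, 1, 1, 1, 1)?;
    // Reduce to a scalar loss so backward() has a single seed gradient.
    let loss = y.sqr()?.sum_all()?;
    let grads = loss.backward()?;
    // With this patch both gradients are populated; previously this path
    // returned Error::BackwardNotSupported { op: "conv1d" }.
    let grad_x = grads.get(&x).expect("missing input gradient");
    let grad_w = grads.get(&w).expect("missing kernel gradient");
    assert_eq!(grad_x.dims(), x.dims());
    assert_eq!(grad_w.dims(), w.dims());
    Ok(())
}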