From d0a330448d7c7dad242f2a6bafca29b8f53dc119 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 29 Aug 2023 10:17:59 +0100
Subject: [PATCH] Backprop support for pooling ops. (#652)

* Backprop support for pooling ops.

* max-pool gradient.
---
 candle-core/src/backprop.rs | 37 +++++++++++++++++++++++++++++++++++--
 1 file changed, 35 insertions(+), 2 deletions(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 4366f3b6..9ecdee4f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -219,8 +219,41 @@ impl Tensor {
                     Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                         op: "conv-transpose2d",
                     })?,
-                    Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
-                    Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
+                    Op::AvgPool2D {
+                        arg,
+                        kernel_size,
+                        stride,
+                    } => {
+                        if kernel_size != stride {
+                            crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
+                        }
+                        let (_n, _c, h, w) = arg.dims4()?;
+                        let grad_arg = grad.upsample_nearest2d(h, w)?;
+                        let grad_arg =
+                            (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad_arg)?;
+                    }
+                    Op::MaxPool2D {
+                        arg,
+                        kernel_size,
+                        stride,
+                    } => {
+                        if kernel_size != stride {
+                            crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
+                        }
+                        let (_n, _c, h, w) = arg.dims4()?;
+                        // For computing the max-pool gradient, we compute a mask where a 1 means
+                        // that the element is the maximum, then we apply this mask to the
+                        // upsampled gradient (taking into account that multiple maxima may
+                        // exist, in which case we scale the gradient accordingly).
+                        let node_upsampled = node.upsample_nearest2d(h, w)?;
+                        let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
+                        let avg = mask.avg_pool2d(*kernel_size, *stride)?;
+                        let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad_arg)?;
+                    }
                     Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                         op: "upsample-nearest2d",
                     })?,
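
Not part of the patch: a minimal sketch of exercising the two new backward
paths, assuming candle-core's Var / backward() / GradStore::get API and the
two-argument pooling signatures that the patch itself uses (kernel size and
stride passed separately). The avg-pool expectation follows directly from the
code above (upsample the output gradient, then scale by 1 / (k_h * k_w), i.e.
0.25 for a 2x2 kernel); for max-pool, the eq-based mask guarantees that only
the per-window maxima receive a non-zero gradient.

    use candle_core::{Device, Result, Tensor, Var};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // A 1x1x4x4 input: a 2x2 pool with stride 2 yields a 1x1x2x2 output
        // and satisfies the kernel_size == stride restriction added above.
        let t = Tensor::arange(0f32, 16f32, &dev)?.reshape((1, 1, 4, 4))?;
        let x = Var::from_tensor(&t)?;

        // Avg-pool backward: with L = sum(y), every input cell should end up
        // with gradient 1 / (2 * 2) = 0.25.
        let y = x.avg_pool2d((2, 2), (2, 2))?;
        let grads = y.sum_all()?.backward()?;
        println!("avg-pool grad: {}", grads.get(&x).unwrap());

        // Max-pool backward: the gradient is zero everywhere except at the
        // argmax cell of each 2x2 window (here the bottom-right cell of each
        // window, since the arange input is strictly increasing).
        let y = x.max_pool2d((2, 2), (2, 2))?;
        let grads = y.sum_all()?.backward()?;
        println!("max-pool grad: {}", grads.get(&x).unwrap());
        Ok(())
    }

Routing the max-pool gradient through eq, avg_pool2d and upsample_nearest2d
keeps the backward pass expressed in existing ops rather than requiring a
dedicated kernel, at the cost of the kernel_size == stride restriction shared
with the avg-pool case.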