mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Backprop support for pooling ops. (#652)
* Backprop support for pooling ops. * max-pool gradient.
This commit is contained in:
@ -219,8 +219,41 @@ impl Tensor {
|
|||||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||||
op: "conv-transpose2d",
|
op: "conv-transpose2d",
|
||||||
})?,
|
})?,
|
||||||
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
|
Op::AvgPool2D {
|
||||||
Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
|
arg,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
} => {
|
||||||
|
if kernel_size != stride {
|
||||||
|
crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
|
||||||
|
}
|
||||||
|
let (_n, _c, h, w) = arg.dims4()?;
|
||||||
|
let grad_arg = grad.upsample_nearest2d(h, w)?;
|
||||||
|
let grad_arg =
|
||||||
|
(grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
|
}
|
||||||
|
Op::MaxPool2D {
|
||||||
|
arg,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
} => {
|
||||||
|
if kernel_size != stride {
|
||||||
|
crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
|
||||||
|
}
|
||||||
|
let (_n, _c, h, w) = arg.dims4()?;
|
||||||
|
// For computing the max-pool gradient, we compute a mask where a 1 means
|
||||||
|
// that the element is the maximum, then we apply this mask to the
|
||||||
|
// upsampled gradient (taking into account that multiple max may exist so
|
||||||
|
// we scale the gradient for this case).
|
||||||
|
let node_upsampled = node.upsample_nearest2d(h, w)?;
|
||||||
|
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
|
||||||
|
let avg = mask.avg_pool2d(*kernel_size, *stride)?;
|
||||||
|
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
|
}
|
||||||
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||||
op: "upsample-nearest2d",
|
op: "upsample-nearest2d",
|
||||||
})?,
|
})?,
|
||||||
|
Reference in New Issue
Block a user