mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Upsample grad (#1420)
* encode size of upsample in enum * working convolution method for limited 2d kernels * add test for sf 3 interpolation * add higher dimensional tests, fix to work with multichannel input * Remove commented out line. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -114,7 +114,7 @@ impl Tensor {
|
||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||
Op::Reshape(node)
|
||||
| Op::UpsampleNearest1D(node)
|
||||
| Op::UpsampleNearest2D(node)
|
||||
| Op::UpsampleNearest2D { arg: node, .. }
|
||||
| Op::AvgPool2D { arg: node, .. }
|
||||
| Op::MaxPool2D { arg: node, .. }
|
||||
| Op::Copy(node)
|
||||
@ -350,9 +350,27 @@ impl Tensor {
|
||||
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest1d",
|
||||
})?,
|
||||
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest2d",
|
||||
})?,
|
||||
Op::UpsampleNearest2D {
|
||||
arg,
|
||||
target_h,
|
||||
target_w,
|
||||
} => {
|
||||
let (_n, c, h, w) = arg.dims4()?;
|
||||
if target_h % h != 0 || target_w % w != 0 {
|
||||
crate::bail!("backward not supported for non integer upscaling factors")
|
||||
}
|
||||
let scale_h = target_h / h;
|
||||
let scale_w = target_w / w;
|
||||
|
||||
if scale_h != scale_w {
|
||||
crate::bail!("backward not supported for non uniform upscaling factors")
|
||||
};
|
||||
let kernel =
|
||||
Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
|
||||
let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = conv_sum;
|
||||
}
|
||||
Op::SliceScatter0(lhs, rhs, start_rhs) => {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
|
||||
|
Reference in New Issue
Block a user