mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add grads for interpolate1d (#1742)
* add backprop for interpolate1d * fix clippy lint * correct fix clippy lint
This commit is contained in:
@ -113,7 +113,7 @@ impl Tensor {
|
||||
| Op::Unary(_node, UnaryOp::Floor)
|
||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||
Op::Reshape(node)
|
||||
| Op::UpsampleNearest1D(node)
|
||||
| Op::UpsampleNearest1D { arg: node, .. }
|
||||
| Op::UpsampleNearest2D { arg: node, .. }
|
||||
| Op::AvgPool2D { arg: node, .. }
|
||||
| Op::MaxPool2D { arg: node, .. }
|
||||
@ -348,9 +348,18 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||
}
|
||||
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
||||
op: "upsample-nearest1d",
|
||||
})?,
|
||||
Op::UpsampleNearest1D { arg, target_size } => {
|
||||
let (_n, c, size) = arg.dims3()?;
|
||||
if target_size % size != 0 {
|
||||
crate::bail!("backward not supported for non integer upscaling factors")
|
||||
}
|
||||
let scale = target_size / size;
|
||||
|
||||
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
|
||||
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = conv_sum;
|
||||
}
|
||||
Op::UpsampleNearest2D {
|
||||
arg,
|
||||
target_h,
|
||||
|
@ -132,7 +132,10 @@ pub enum Op {
|
||||
stride: (usize, usize),
|
||||
},
|
||||
|
||||
UpsampleNearest1D(Tensor),
|
||||
UpsampleNearest1D {
|
||||
arg: Tensor,
|
||||
target_size: usize,
|
||||
},
|
||||
UpsampleNearest2D {
|
||||
arg: Tensor,
|
||||
target_h: usize,
|
||||
|
@ -1015,7 +1015,7 @@ impl Tensor {
|
||||
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
||||
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
||||
let (n, c, _l) = self.dims3()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
|
||||
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
|
||||
let storage = self
|
||||
.storage()
|
||||
.upsample_nearest1d(self.layout(), target_size)?;
|
||||
|
Reference in New Issue
Block a user