mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add grads for interpolate1d (#1742)
* add backprop for interpolate1d
* fix clippy lint
* correct fix clippy lint
This commit is contained in:
@ -113,7 +113,7 @@ impl Tensor {
|
|||||||
| Op::Unary(_node, UnaryOp::Floor)
|
| Op::Unary(_node, UnaryOp::Floor)
|
||||||
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
| Op::Unary(_node, UnaryOp::Round) => nodes,
|
||||||
Op::Reshape(node)
|
Op::Reshape(node)
|
||||||
| Op::UpsampleNearest1D(node)
|
| Op::UpsampleNearest1D { arg: node, .. }
|
||||||
| Op::UpsampleNearest2D { arg: node, .. }
|
| Op::UpsampleNearest2D { arg: node, .. }
|
||||||
| Op::AvgPool2D { arg: node, .. }
|
| Op::AvgPool2D { arg: node, .. }
|
||||||
| Op::MaxPool2D { arg: node, .. }
|
| Op::MaxPool2D { arg: node, .. }
|
||||||
@ -348,9 +348,18 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&grad_arg)?;
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
}
|
}
|
||||||
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
|
Op::UpsampleNearest1D { arg, target_size } => {
|
||||||
op: "upsample-nearest1d",
|
let (_n, c, size) = arg.dims3()?;
|
||||||
})?,
|
if target_size % size != 0 {
|
||||||
|
crate::bail!("backward not supported for non integer upscaling factors")
|
||||||
|
}
|
||||||
|
let scale = target_size / size;
|
||||||
|
|
||||||
|
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
|
||||||
|
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
*sum_grad = conv_sum;
|
||||||
|
}
|
||||||
Op::UpsampleNearest2D {
|
Op::UpsampleNearest2D {
|
||||||
arg,
|
arg,
|
||||||
target_h,
|
target_h,
|
||||||
|
@ -132,7 +132,10 @@ pub enum Op {
|
|||||||
stride: (usize, usize),
|
stride: (usize, usize),
|
||||||
},
|
},
|
||||||
|
|
||||||
UpsampleNearest1D(Tensor),
|
UpsampleNearest1D {
|
||||||
|
arg: Tensor,
|
||||||
|
target_size: usize,
|
||||||
|
},
|
||||||
UpsampleNearest2D {
|
UpsampleNearest2D {
|
||||||
arg: Tensor,
|
arg: Tensor,
|
||||||
target_h: usize,
|
target_h: usize,
|
||||||
|
@ -1015,7 +1015,7 @@ impl Tensor {
|
|||||||
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
||||||
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
||||||
let (n, c, _l) = self.dims3()?;
|
let (n, c, _l) = self.dims3()?;
|
||||||
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
|
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
|
||||||
let storage = self
|
let storage = self
|
||||||
.storage()
|
.storage()
|
||||||
.upsample_nearest1d(self.layout(), target_size)?;
|
.upsample_nearest1d(self.layout(), target_size)?;
|
||||||
|
@ -283,6 +283,39 @@ fn unary_grad(device: &Device) -> Result<()> {
|
|||||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
|
||||||
|
let y = x.interpolate1d(12)?.reshape(36)?;
|
||||||
|
|
||||||
|
println!("y: {}", y.unsqueeze(1)?);
|
||||||
|
#[rustfmt::skip]
|
||||||
|
let z = Tensor::new(
|
||||||
|
&[
|
||||||
|
1_f32, 02., 03., 04.,
|
||||||
|
05., 06., 07., 08.,
|
||||||
|
09., 10., 11., 12.,
|
||||||
|
13., 14., 15., 16.,
|
||||||
|
17., 18., 19., 20.,
|
||||||
|
21., 22., 23., 24.,
|
||||||
|
25., 26., 27., 28.,
|
||||||
|
29., 30., 31., 32.,
|
||||||
|
33., 34., 35., 36.,
|
||||||
|
],
|
||||||
|
device,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||||
|
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
|
||||||
|
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||||
|
|
||||||
|
println!("grad: {grad_x}");
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(grad_x, 4)?,
|
||||||
|
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
|
||||||
|
);
|
||||||
|
|
||||||
// manually checked: see comments
|
// manually checked: see comments
|
||||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||||
|
Reference in New Issue
Block a user