mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix the conv2d gradient computation. (#1214)
This commit is contained in:
@ -238,6 +238,13 @@ impl Tensor {
|
|||||||
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
|
||||||
.transpose(0, 1)?;
|
.transpose(0, 1)?;
|
||||||
let sum_grad = grads.or_insert(kernel)?;
|
let sum_grad = grads.or_insert(kernel)?;
|
||||||
|
let (_, _, k0, k1) = kernel.dims4()?;
|
||||||
|
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
|
||||||
|
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
|
||||||
|
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
|
||||||
|
} else {
|
||||||
|
grad_kernel
|
||||||
|
};
|
||||||
*sum_grad = sum_grad.add(&grad_kernel)?;
|
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||||
}
|
}
|
||||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||||
|
@ -479,6 +479,71 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Replicate the issue from https://github.com/huggingface/candle/issues/1212
|
||||||
|
let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
|
||||||
|
let loss = res.sqr()?.sum_all()?;
|
||||||
|
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
let grad_t = grads.get(&t).unwrap();
|
||||||
|
let grad_w = grads.get(&w).unwrap();
|
||||||
|
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
|
||||||
|
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[9.29, -7.03, 7.87, 0.0, 0.0],
|
||||||
|
[-1.8, -7.82, 5.9, 0.0, 0.0],
|
||||||
|
[-3.12, 4.49, 5.52, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[21.73, 3.39, 4.77, 0.0, 0.0],
|
||||||
|
[8.25, 3.73, 27.61, 0.0, 0.0],
|
||||||
|
[-20.55, -5.61, -2.77, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[-8.98, 9.91, -7.15, 0.0, 0.0],
|
||||||
|
[4.93, -0.33, 4.56, 0.0, 0.0],
|
||||||
|
[-6.7, -5.76, -8.05, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[23.54, 6.98, -10.0, 0.0, 0.0],
|
||||||
|
[9.65, 6.18, 18.72, 0.0, 0.0],
|
||||||
|
[3.29, -5.27, 0.79, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[-3.47, 7.44, 0.66],
|
||||||
|
[12.89, -3.4, -9.29],
|
||||||
|
[-14.16, -0.83, 7.14]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[-3.23, 5.37, -3.02],
|
||||||
|
[-2.12, -11.24, 1.94],
|
||||||
|
[6.97, 7.2, 2.99]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[-4.04, -3.31, 4.87],
|
||||||
|
[-6.68, -5.68, 1.73],
|
||||||
|
[-5.54, 4.32, 0.52]
|
||||||
|
],
|
||||||
|
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user