mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Backprop for conv2d. (#638)
* Start adding backprop for conv2d. * Backprop for conv2d. * Bugfix + start adding a conv2d test. * Conv2d backprop testing. * More conv fixes.
This commit is contained in:
@ -192,7 +192,30 @@ impl Tensor {
|
|||||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||||
}
|
}
|
||||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||||
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
|
Op::Conv2D {
|
||||||
|
arg,
|
||||||
|
kernel,
|
||||||
|
padding,
|
||||||
|
stride,
|
||||||
|
} => {
|
||||||
|
// The output height for conv_transpose2d is:
|
||||||
|
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
|
||||||
|
let grad_h = grad.dim(2)?;
|
||||||
|
let k_h = kernel.dim(2)?;
|
||||||
|
let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
|
||||||
|
let out_padding = arg.dim(2)? - out_size;
|
||||||
|
let grad_arg =
|
||||||
|
grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
*sum_grad = sum_grad.add(&grad_arg)?;
|
||||||
|
|
||||||
|
let grad_kernel = arg
|
||||||
|
.transpose(0, 1)?
|
||||||
|
.conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
|
||||||
|
.transpose(0, 1)?;
|
||||||
|
let sum_grad = grads.or_insert(kernel)?;
|
||||||
|
*sum_grad = sum_grad.add(&grad_kernel)?;
|
||||||
|
}
|
||||||
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
|
||||||
op: "conv-transpose2d",
|
op: "conv-transpose2d",
|
||||||
})?,
|
})?,
|
||||||
|
@ -71,18 +71,14 @@ pub struct ParamsConvTranspose2D {
|
|||||||
impl ParamsConvTranspose2D {
|
impl ParamsConvTranspose2D {
|
||||||
pub(crate) fn out_h(&self) -> usize {
|
pub(crate) fn out_h(&self) -> usize {
|
||||||
let dilation = 1;
|
let dilation = 1;
|
||||||
(self.i_h - 1) * self.stride - 2 * self.padding
|
(self.i_h - 1) * self.stride + dilation * (self.k_h - 1) + self.output_padding + 1
|
||||||
+ dilation * (self.k_h - 1)
|
- 2 * self.padding
|
||||||
+ self.output_padding
|
|
||||||
+ 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn out_w(&self) -> usize {
|
pub(crate) fn out_w(&self) -> usize {
|
||||||
let dilation = 1;
|
let dilation = 1;
|
||||||
(self.i_w - 1) * self.stride - 2 * self.padding
|
(self.i_w - 1) * self.stride + dilation * (self.k_w - 1) + self.output_padding + 1
|
||||||
+ dilation * (self.k_w - 1)
|
- 2 * self.padding
|
||||||
+ self.output_padding
|
|
||||||
+ 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
pub(crate) fn out_dims(&self) -> Vec<usize> {
|
||||||
|
@ -1204,7 +1204,7 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
|||||||
let inp_x = out_x * p.stride as i32 - p.padding as i32;
|
let inp_x = out_x * p.stride as i32 - p.padding as i32;
|
||||||
let inp_y = out_y * p.stride as i32 - p.padding as i32;
|
let inp_y = out_y * p.stride as i32 - p.padding as i32;
|
||||||
for k_y in 0..p.k_h as i32 {
|
for k_y in 0..p.k_h as i32 {
|
||||||
for k_x in 0..p.k_h as i32 {
|
for k_x in 0..p.k_w as i32 {
|
||||||
let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
|
let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
|
||||||
let inp_y = inp_y + k_y;
|
let inp_y = inp_y + k_y;
|
||||||
let inp_x = inp_x + k_x;
|
let inp_x = inp_x + k_x;
|
||||||
@ -1215,9 +1215,11 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
|||||||
let inp_y = inp_y as usize;
|
let inp_y = inp_y as usize;
|
||||||
if inp_x < p.i_w && inp_y < p.i_h {
|
if inp_x < p.i_w && inp_y < p.i_h {
|
||||||
let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
|
let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
|
||||||
let dst_index = b_idx * dst_s0 + inp_y * dst_s2 + inp_x * dst_s3;
|
let dst_index = b_idx * dst_s0
|
||||||
for c_out in 0..k_s0 {
|
+ out_y as usize * dst_s2
|
||||||
for c_in in 0..k_s1 {
|
+ out_x as usize * dst_s3;
|
||||||
|
for c_out in 0..p.c_out {
|
||||||
|
for c_in in 0..p.c_in {
|
||||||
let k_index = k_index + c_out * k_s1 + c_in * k_s0;
|
let k_index = k_index + c_out * k_s1 + c_in * k_s0;
|
||||||
let dst_index = dst_index + c_out * dst_s1;
|
let dst_index = dst_index + c_out * dst_s1;
|
||||||
let inp_index = inp_index + c_in * inp_s1;
|
let inp_index = inp_index + c_in * inp_s1;
|
||||||
|
@ -243,6 +243,78 @@ fn conv2d_non_square(dev: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conv2d_grad() -> Result<()> {
|
||||||
|
use candle_core::Var;
|
||||||
|
let dev = &Device::Cpu;
|
||||||
|
let t = Var::from_slice(
|
||||||
|
&[
|
||||||
|
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||||
|
1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, 0.2395,
|
||||||
|
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
|
||||||
|
0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, 1.3123,
|
||||||
|
1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, 1.1586,
|
||||||
|
0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, 0.2049,
|
||||||
|
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
|
||||||
|
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
|
||||||
|
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
|
||||||
|
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
|
||||||
|
],
|
||||||
|
(1, 4, 5, 5),
|
||||||
|
dev,
|
||||||
|
)?;
|
||||||
|
let w = Var::from_slice(
|
||||||
|
&[
|
||||||
|
-0.9325f32, 0.6451, -0.8537, 0.2378, 0.8764, -0.1832, 0.2987, -0.6488, -0.2273,
|
||||||
|
-2.4184, -0.1192, -0.4821, -0.5079, -0.5766, -2.4729, 1.6734, 0.4558, 0.2851, 1.1514,
|
||||||
|
-0.9013, 1.0662, -0.1817, -0.0259, 0.1709, 0.5367, 0.7513, 0.8086, -2.2586, -0.5027,
|
||||||
|
0.9141, -1.3086, -1.3343, -1.5669, -0.1657, 0.7958, 0.1432, 0.3896, -0.4501, 0.1667,
|
||||||
|
0.0714, -0.0952, 1.2970, -0.1674, -0.3178, 1.0677, 0.3060, 0.7080, 0.1914, 1.1679,
|
||||||
|
-0.3602, 1.9265, -1.8626, -0.5112, -0.0982, 0.2621, 0.6565, 0.5908, 1.0089, -0.1646,
|
||||||
|
1.8032, -0.6286, 0.2016, -0.3370, 1.2555, 0.8009, -0.6488, -0.4652, -1.5685, 1.5860,
|
||||||
|
0.5583, 0.4623, 0.6026,
|
||||||
|
],
|
||||||
|
(2, 4, 3, 3),
|
||||||
|
dev,
|
||||||
|
)?;
|
||||||
|
let res = t.conv2d(&w, 0, 1, 1)?;
|
||||||
|
let loss = res.sqr()?.sum_all()?;
|
||||||
|
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 741.12f32);
|
||||||
|
let grads = loss.backward()?;
|
||||||
|
let grad_t = grads.get(&t).unwrap();
|
||||||
|
let grad_w = grads.get(&w).unwrap();
|
||||||
|
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
|
||||||
|
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&grad_t.flatten_all()?, 4)?,
|
||||||
|
// THIS IS WRONG AT THE MOMENT
|
||||||
|
[
|
||||||
|
1.7442, -10.1747, -9.9426, 0.0, 0.0, -1.7046, -21.2248, 30.8435, 0.0, 0.0, -18.713,
|
||||||
|
-1.0547, -7.8746, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 169.3047,
|
||||||
|
46.0812, 40.6937, 0.0, 0.0, -85.8156, 4.537, 53.2871, 0.0, 0.0, -59.632, -35.9725,
|
||||||
|
-7.1689, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 48.823, 8.9794,
|
||||||
|
42.3011, 0.0, 0.0, -58.9268, 32.907, -50.6863, 0.0, 0.0, -0.9706, -3.9175, -4.2594,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 72.8229, 25.8492, 8.9871,
|
||||||
|
0.0, 0.0, -136.2584, 40.1739, 88.9583, 0.0, 0.0, -53.465, -40.7102, -24.9406, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
|
||||||
|
]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&grad_w.flatten_all()?, 4)?,
|
||||||
|
[
|
||||||
|
-28.9232, -22.8833, -141.2296, 73.3462, 61.074, 47.8125, -20.0013, -73.7086, -41.8217,
|
||||||
|
-13.5919, 21.501, 28.7179, 28.5683, -46.8486, -90.1874, 143.6107, 16.6764, 7.4259,
|
||||||
|
18.8794, -90.8122, -20.2865, 54.7909, 82.6287, 22.943, 77.8084, -16.3928, -13.1977,
|
||||||
|
9.3442, -40.3869, -26.6153, 5.3344, -60.9081, 9.0869, -59.368, 7.081, 58.6391, 5.5476,
|
||||||
|
20.5152, 2.4985, -17.2466, -6.802, 22.2146, 30.1511, -7.5179, -37.4588, 5.6654,
|
||||||
|
22.5832, 9.0316, 47.0547, 17.6123, 37.3121, -98.1295, -14.6141, -4.7958, -6.3597,
|
||||||
|
44.6949, 23.3418, 8.3728, -13.52, 80.0522, -34.2403, -16.3648, -12.3139, 1.9195,
|
||||||
|
-33.6244, -14.102, -49.2305, -7.3853, 11.4995, -9.9826, 9.6588, 29.6042
|
||||||
|
]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
||||||
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
||||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
||||||
|
Reference in New Issue
Block a user