Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Fix the cpu kernel for conv-transpose. (#643)
@@ -1199,25 +1199,23 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
         let dst_s2 = out_w;
         let dst_s3 = 1;
         for b_idx in 0..p.b_size {
-            for out_y in 0..out_h as i32 {
-                for out_x in 0..out_w as i32 {
-                    let inp_x = out_x * p.stride as i32 - p.padding as i32;
-                    let inp_y = out_y * p.stride as i32 - p.padding as i32;
+            for inp_y in 0..p.i_h {
+                for inp_x in 0..p.i_w {
+                    let out_x = (inp_x * p.stride) as i32 - p.padding as i32;
+                    let out_y = (inp_y * p.stride) as i32 - p.padding as i32;
                     for k_y in 0..p.k_h as i32 {
                         for k_x in 0..p.k_w as i32 {
                             let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
-                            let inp_y = inp_y + k_y;
-                            let inp_x = inp_x + k_x;
-                            if inp_x < 0 || inp_y < 0 {
+                            let out_y = out_y + k_y;
+                            let out_x = out_x + k_x;
+                            if out_x < 0 || out_y < 0 {
                                 continue;
                             }
-                            let inp_x = inp_x as usize;
-                            let inp_y = inp_y as usize;
-                            if inp_x < p.i_w && inp_y < p.i_h {
+                            let out_x = out_x as usize;
+                            let out_y = out_y as usize;
+                            if out_x < out_w && out_y < out_h {
                                 let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
-                                let dst_index = b_idx * dst_s0
-                                    + out_y as usize * dst_s2
-                                    + out_x as usize * dst_s3;
+                                let dst_index = b_idx * dst_s0 + out_y * dst_s2 + out_x * dst_s3;
                                 for c_out in 0..p.c_out {
                                     for c_in in 0..p.c_in {
                                         let k_index = k_index + c_out * k_s1 + c_in * k_s0;
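The old loop treated the operation like a forward convolution: it iterated over output pixels and derived input coordinates via `inp = out * stride - padding`. A transposed convolution goes the other way around: each input pixel scatters into the window of output positions starting at `inp * stride - padding`. Below is a minimal, self-contained 1D sketch of that scatter indexing (a hypothetical helper written for this note, not candle's API), with the same negative-coordinate and upper-bound checks the fixed 2D kernel performs:

```rust
fn conv_transpose1d(inp: &[f32], kernel: &[f32], stride: usize, padding: usize) -> Vec<f32> {
    // Standard transposed-convolution output size (no output padding/dilation,
    // and padding assumed small enough that this does not underflow).
    let out_len = (inp.len() - 1) * stride + kernel.len() - 2 * padding;
    let mut out = vec![0f32; out_len];
    for (inp_idx, &v) in inp.iter().enumerate() {
        // Base output coordinate; may go negative once padding is subtracted,
        // hence the signed arithmetic, exactly as in the fixed kernel.
        let base = (inp_idx * stride) as i32 - padding as i32;
        for (k, &w) in kernel.iter().enumerate() {
            let out_idx = base + k as i32;
            // Mirrors the `if out_x < 0 || out_y < 0 { continue; }` and the
            // `out_x < out_w && out_y < out_h` bounds checks above.
            if out_idx < 0 || out_idx as usize >= out_len {
                continue;
            }
            out[out_idx as usize] += v * w;
        }
    }
    out
}

fn main() {
    // Each input element scatters a scaled copy of the kernel into the output:
    // [1, 1] against [1, 2, 3] at stride 1, padding 0 gives [1, 3, 5, 3].
    assert_eq!(
        conv_transpose1d(&[1.0, 1.0], &[1.0, 2.0, 3.0], 1, 0),
        [1.0, 3.0, 5.0, 3.0]
    );
}
```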
@@ -1,5 +1,5 @@
 use anyhow::Result;
-use candle_core::{test_device, test_utils, Device, Tensor};
+use candle_core::{test_device, test_utils, Device, IndexOp, Tensor};
 
 /* This test is based on the following script.
 import torch
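The widened import matters for the new assertions further down: `IndexOp` is the trait that provides the `.i(...)` indexing used in `res.i(0)?`.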
@@ -76,6 +76,11 @@ print(t.flatten())
 print(w.flatten())
 res = torch.nn.functional.conv2d(t, w)
 print(res.flatten())
+
+w_t = w.transpose(0, 1)
+res = torch.nn.functional.conv_transpose2d(t, w_t)
+print(res.shape)
+print(res)
 */
 fn conv2d(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -117,6 +122,33 @@ fn conv2d(dev: &Device) -> Result<()> {
             10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
         ]
     );
+    if dev.is_cpu() {
+        let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1)?;
+        assert_eq!(res.dims(), [1, 2, 7, 7]);
+        assert_eq!(
+            test_utils::to_vec3_round(&res.i(0)?, 4)?,
+            [
+                [
+                    [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
+                    [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
+                    [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
+                    [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
+                    [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
+                    [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
+                    [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
+                ],
+                [
+                    [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
+                    [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
+                    [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
+                    [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
+                    [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
+                    [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
+                    [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
+                ]
+            ]
+        );
+    }
     Ok(())
 }
 
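The asserted `[1, 2, 7, 7]` shape follows from the usual transposed-convolution size formula. The new block reuses the surrounding test's input `t` and weight `w`; the `7x7` result with a `3x3` kernel at stride 1 and padding 0 implies a `5x5` input. The `w.transpose(0, 1)` swaps the first two weight dims because, following the PyTorch convention the test script mirrors, `conv_transpose2d` expects weights laid out `(c_in, c_out, k_h, k_w)` while `conv2d` uses `(c_out, c_in, k_h, k_w)`. A quick check of the formula (a throwaway helper, with output padding and dilation assumed absent):

```rust
// Transposed-convolution output size:
//   out = (in - 1) * stride - 2 * padding + kernel
fn conv_transpose_out_dim(inp: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (inp - 1) * stride - 2 * padding + kernel
}

fn main() {
    // The test's 5x5 input with a 3x3 kernel, stride 1, padding 0:
    // (5 - 1) * 1 - 0 + 3 = 7, matching the asserted [1, 2, 7, 7].
    assert_eq!(conv_transpose_out_dim(5, 3, 1, 0), 7);
}
```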
@@ -287,16 +319,18 @@ fn conv2d_grad() -> Result<()> {
     assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
     assert_eq!(
         test_utils::to_vec1_round(&grad_t.flatten_all()?, 4)?,
-        // THIS IS WRONG AT THE MOMENT
         [
-            1.7442, -10.1747, -9.9426, 0.0, 0.0, -1.7046, -21.2248, 30.8435, 0.0, 0.0, -18.713,
-            -1.0547, -7.8746, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 169.3047,
-            46.0812, 40.6937, 0.0, 0.0, -85.8156, 4.537, 53.2871, 0.0, 0.0, -59.632, -35.9725,
-            -7.1689, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 48.823, 8.9794,
-            42.3011, 0.0, 0.0, -58.9268, 32.907, -50.6863, 0.0, 0.0, -0.9706, -3.9175, -4.2594,
-            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 72.8229, 25.8492, 8.9871,
-            0.0, 0.0, -136.2584, 40.1739, 88.9583, 0.0, 0.0, -53.465, -40.7102, -24.9406, 0.0, 0.0,
-            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
+            9.2868, -2.8352, -5.7117, 3.3817, -7.7094, -19.1549, 7.016, 29.1037, 9.3411, 34.7339,
+            -22.8726, 24.3502, -39.88, -14.007, 21.076, 9.9419, 13.6333, -34.6796, 11.2073,
+            -6.2617, 7.7209, -6.3224, -16.6373, -1.0837, -20.2215, 21.7302, -0.3744, -4.0573,
+            5.8163, -3.6529, -30.7319, 14.5468, 87.699, 31.6035, 4.5304, -89.785, -75.3709,
+            -57.4327, -7.5602, 92.9585, 18.791, -4.6311, -159.7521, -42.4656, -47.2644, 52.8768,
+            37.3172, 48.9978, 12.8192, 2.014, -8.9826, 20.1759, 16.621, 12.0599, 15.3849, 19.9979,
+            2.5725, -15.2197, 72.6244, -10.7496, 2.2541, -31.2003, 3.753, -0.2049, 9.7574, -0.6824,
+            5.2107, -40.4361, -22.5891, -61.6085, 17.2837, 20.4149, 37.5454, 5.2262, 6.8126,
+            23.5361, 23.6173, -9.9866, -9.1324, 4.8664, -35.0617, -26.1023, 63.4757, 25.8144,
+            -39.2069, -70.6834, -46.9565, 2.3252, 41.8093, 82.4205, -28.626, -11.7812, -35.3284,
+            -10.2771, -28.5694, -9.1258, 7.213, -9.0459, -9.6222, -11.2544
         ]
     );
     assert_eq!(
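The previously wrong `grad_t` expectations (note the dropped `// THIS IS WRONG AT THE MOMENT`) are corrected by the same kernel fix, presumably because candle routes `conv2d`'s input gradient through `conv_transpose2d`; the updated expectations landing in this same commit point that way. Mathematically, the gradient of a convolution with respect to its input is a transposed convolution of the output gradient. A small self-contained 1D sanity check of that identity (sketch code, not candle's API):

```rust
// Forward 1D convolution, stride 1, no padding.
fn conv1d(x: &[f32], w: &[f32]) -> Vec<f32> {
    (0..=x.len() - w.len())
        .map(|j| w.iter().enumerate().map(|(k, &wk)| x[j + k] * wk).sum::<f32>())
        .collect()
}

// Transposed 1D convolution, stride 1, no padding: each output-grad element
// scatters a scaled copy of the kernel.
fn conv_transpose1d(g: &[f32], w: &[f32]) -> Vec<f32> {
    let mut out = vec![0f32; g.len() + w.len() - 1];
    for (j, &gj) in g.iter().enumerate() {
        for (k, &wk) in w.iter().enumerate() {
            out[j + k] += gj * wk;
        }
    }
    out
}

fn main() {
    let x = [1.0f32, 2.0, -1.0, 0.5, 3.0];
    let w = [0.5f32, -1.0, 2.0];
    let base: f32 = conv1d(&x, &w).iter().sum();
    let eps = 1e-3f32;
    // Finite-difference gradient of sum(conv1d(x, w)) w.r.t. each x[i].
    let grad_fd: Vec<f32> = (0..x.len())
        .map(|i| {
            let mut xp = x;
            xp[i] += eps;
            let pert: f32 = conv1d(&xp, &w).iter().sum();
            (pert - base) / eps
        })
        .collect();
    // The same gradient via a conv-transpose of an all-ones output gradient.
    let ones = vec![1.0f32; x.len() - w.len() + 1];
    let grad_ct = conv_transpose1d(&ones, &w);
    println!("finite-diff: {grad_fd:?}");
    println!("conv-transp: {grad_ct:?}");
}
```

With stride 1 and no padding the two vectors agree up to finite-difference error, which is the relationship the fixed CPU kernel restores for the 2D case.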