mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix the dilated convolutions. (#659)
This commit is contained in:
@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
|
||||
for dst_l in 0..l_out {
|
||||
let dst_idx = dst_idx + dst_l;
|
||||
let src_l = (p.stride * dst_l + offset) * p.dilation;
|
||||
let src_l = p.stride * dst_l + offset * p.dilation;
|
||||
if src_l < p.padding || src_l >= p.padding + p.l_in {
|
||||
continue;
|
||||
}
|
||||
@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
|
||||
for dst_h in 0..out_h {
|
||||
let dst_idx = dst_idx + dst_h * out_w;
|
||||
let src_h = (p.stride * dst_h + offset_h) * p.dilation;
|
||||
let src_h = p.stride * dst_h + offset_h * p.dilation;
|
||||
if src_h < p.padding || src_h >= p.i_h + p.padding {
|
||||
continue;
|
||||
}
|
||||
let src_h = src_h - p.padding;
|
||||
for dst_w in 0..out_w {
|
||||
let dst_idx = dst_idx + dst_w;
|
||||
let src_w = (p.stride * dst_w + offset_w) * p.dilation;
|
||||
let src_w = p.stride * dst_w + offset_w * p.dilation;
|
||||
if src_w < p.padding || src_w >= p.i_w + p.padding {
|
||||
continue;
|
||||
}
|
||||
|
@ -423,24 +423,24 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
|
||||
[
|
||||
[
|
||||
[28.34, -45.75, 7.32],
|
||||
[0.72, -35.28, 19.23],
|
||||
[-28.29, 20.89, -5.18]
|
||||
[28.34, -7.91, -45.75],
|
||||
[21.03, 3.86, 29.86],
|
||||
[0.72, -36.58, -35.28]
|
||||
],
|
||||
[
|
||||
[-16.04, -16.38, 32.12],
|
||||
[57.5, 25.81, 11.96],
|
||||
[-18.66, 8.48, -9.92]
|
||||
[-16.04, 11.53, -16.38],
|
||||
[29.62, -16.32, -48.35],
|
||||
[57.5, 28.29, 25.81]
|
||||
],
|
||||
[
|
||||
[2.93, 1.57, -23.76],
|
||||
[12.74, -26.2, -17.88],
|
||||
[-14.98, -9.35, 12.2]
|
||||
[2.93, -19.6, 1.57],
|
||||
[27.15, 53.88, -24.64],
|
||||
[12.74, -22.6, -26.2]
|
||||
],
|
||||
[
|
||||
[-0.18, -6.82, 20.79],
|
||||
[-2.54, 27.11, -10.11],
|
||||
[-0.41, -3.18, -0.07]
|
||||
[-0.18, -14.86, -6.82],
|
||||
[-19.55, -2.72, 45.9],
|
||||
[-2.54, 36.97, 27.11]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
@ -92,13 +92,13 @@ __device__ void conv2d(
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
A d = 0;
|
||||
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
|
||||
size_t src_w = (stride * dst_w + w_offset) * dilation;
|
||||
size_t src_w = stride * dst_w + w_offset * dilation;
|
||||
if (src_w < padding || src_w >= w_in + padding) {
|
||||
continue;
|
||||
}
|
||||
src_w -= padding;
|
||||
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
|
||||
size_t src_h = (stride * dst_h + h_offset) * dilation;
|
||||
size_t src_h = stride * dst_h + h_offset * dilation;
|
||||
if (src_h < padding || src_h >= h_in + padding) {
|
||||
continue;
|
||||
}
|
||||
|
Reference in New Issue
Block a user