mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Use zero padding in conv1d and conv2d (same as pytorch). (#408)
This commit is contained in:
@ -1057,9 +1057,11 @@ impl<'a> Map2 for Conv1D<'a> {
|
|||||||
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
|
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
|
||||||
for dst_l in 0..l_out {
|
for dst_l in 0..l_out {
|
||||||
let dst_idx = dst_idx + dst_l;
|
let dst_idx = dst_idx + dst_l;
|
||||||
let src_l = (p.stride * dst_l + offset)
|
let src_l = p.stride * dst_l + offset;
|
||||||
.saturating_sub(p.padding)
|
if src_l < p.padding || src_l >= p.padding + p.l_in {
|
||||||
.min(p.l_in - 1);
|
continue;
|
||||||
|
}
|
||||||
|
let src_l = src_l - p.padding;
|
||||||
let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
|
let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
|
||||||
assert!(inp_cont.len() >= p.c_in);
|
assert!(inp_cont.len() >= p.c_in);
|
||||||
assert!(k_cont.len() >= p.c_in);
|
assert!(k_cont.len() >= p.c_in);
|
||||||
@ -1132,14 +1134,18 @@ impl<'a> Map2 for Conv2D<'a> {
|
|||||||
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
|
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
|
||||||
for dst_h in 0..out_h {
|
for dst_h in 0..out_h {
|
||||||
let dst_idx = dst_idx + dst_h * out_w;
|
let dst_idx = dst_idx + dst_h * out_w;
|
||||||
let src_h = (p.stride * dst_h + offset_h)
|
let src_h = p.stride * dst_h + offset_h;
|
||||||
.saturating_sub(p.padding)
|
if src_h < p.padding || src_h >= p.i_h + p.padding {
|
||||||
.min(p.i_h - 1);
|
continue;
|
||||||
|
}
|
||||||
|
let src_h = src_h - p.padding;
|
||||||
for dst_w in 0..out_w {
|
for dst_w in 0..out_w {
|
||||||
let dst_idx = dst_idx + dst_w;
|
let dst_idx = dst_idx + dst_w;
|
||||||
let src_w = (p.stride * dst_w + offset_w)
|
let src_w = p.stride * dst_w + offset_w;
|
||||||
.saturating_sub(p.padding)
|
if src_w < p.padding || src_w >= p.i_w + p.padding {
|
||||||
.min(p.i_w - 1);
|
continue;
|
||||||
|
}
|
||||||
|
let src_w = src_w - p.padding;
|
||||||
let inp_cont = &inp_cont
|
let inp_cont = &inp_cont
|
||||||
[b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
|
[b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
|
||||||
assert!(inp_cont.len() >= p.c_in);
|
assert!(inp_cont.len() >= p.c_in);
|
||||||
|
Reference in New Issue
Block a user