Simplify the conv1d and conv2d code. (#352)

This commit is contained in:
Laurent Mazare
2023-08-08 23:10:59 +02:00
committed by GitHub
parent b9864e1357
commit cf965ecaa8

View File

@ -997,7 +997,6 @@ impl<'a> Map2 for Conv1D<'a> {
(0, inp_stride) // This value never gets used anyway (0, inp_stride) // This value never gets used anyway
}; };
let k_stride = k_l.stride(); let k_stride = k_l.stride();
let k_over_2 = p.k_size / 2;
let l_out = p.l_out(); let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1); let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
let mut dst = vec![T::zero(); dst_elems]; let mut dst = vec![T::zero(); dst_elems];
@ -1011,10 +1010,9 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + dst_l; let dst_idx = dst_idx + dst_l;
let mut d = T::zero(); let mut d = T::zero();
for offset in 0..p.k_size { for offset in 0..p.k_size {
let src_l_plus = p.stride * dst_l + offset + k_over_2 - p.padding; let src_l = (p.stride * dst_l + offset)
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] .saturating_sub(p.padding)
if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in { .min(p.l_in - 1);
let src_l = src_l_plus - k_over_2;
for src_c_idx in 0..p.c_in { for src_c_idx in 0..p.c_in {
let inp_idx = let inp_idx =
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
@ -1024,7 +1022,6 @@ impl<'a> Map2 for Conv1D<'a> {
d += inp[inp_idx] * k[k_idx] d += inp[inp_idx] * k[k_idx]
} }
} }
}
dst[dst_idx] = d dst[dst_idx] = d
} }
} }
@ -1064,15 +1061,13 @@ impl<'a> Map2 for Conv2D<'a> {
let mut d = T::zero(); let mut d = T::zero();
for offset_h in 0..p.k_h { for offset_h in 0..p.k_h {
// TODO: Handle the case where padding is larger than p.k_h / 2. // TODO: Handle the case where padding is larger than p.k_h / 2.
let src_h_plus = p.stride * dst_h + offset_h + p.k_h / 2 - p.padding; let src_h = (p.stride * dst_h + offset_h)
if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h { .saturating_sub(p.padding)
let src_h = src_h_plus - p.k_h / 2; .min(p.i_h - 1);
for offset_w in 0..p.k_w { for offset_w in 0..p.k_w {
let src_w_plus = let src_w = (p.stride * dst_w + offset_w)
p.stride * dst_w + offset_w + p.k_w / 2 - p.padding; .saturating_sub(p.padding)
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] .min(p.i_w - 1);
if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w {
let src_w = src_w_plus - p.k_w / 2;
for src_c_idx in 0..p.c_in { for src_c_idx in 0..p.c_in {
let inp_idx = inp_idx let inp_idx = inp_idx
+ src_c_idx * inp_stride[1] + src_c_idx * inp_stride[1]
@ -1086,8 +1081,6 @@ impl<'a> Map2 for Conv2D<'a> {
} }
} }
} }
}
}
dst[dst_idx] = d dst[dst_idx] = d
} }
} }