mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Properly handle the stride in conv1d.
This commit is contained in:
@ -238,9 +238,10 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
let dst_idx = dst_idx + dst_l;
|
||||
let mut d = T::zero();
|
||||
for offset in 0..p.k_size {
|
||||
let src_l_plus = p.stride * dst_l + offset;
|
||||
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
|
||||
if k_over_2 <= dst_l + offset && dst_l + offset < k_over_2 + p.l_in {
|
||||
let src_l = dst_l + offset - k_over_2;
|
||||
if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in {
|
||||
let src_l = src_l_plus - k_over_2;
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx =
|
||||
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
|
||||
|
Reference in New Issue
Block a user