diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 718b071c..4eb57bc7 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -206,18 +206,55 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D); impl<'a> Map2 for Conv1D<'a> { const OP: &'static str = "conv1d"; - fn f( + fn f( &self, - _inp: &[T], - _inp_l: &Layout, - _k: &[T], - _k_l: &Layout, + inp: &[T], + inp_l: &Layout, + k: &[T], + k_l: &Layout, ) -> Result> { + // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc). let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let k = &k[k_l.start_offset()..]; + let inp_stride = inp_l.stride(); + let (inp_stride0, inp_stride) = if inp_stride.len() == 3 { + (inp_stride[0], &inp_stride[1..]) + } else { + (0, inp_stride) // This value never gets used anyway + }; + let k_stride = k_l.stride(); + let k_over_2 = p.k_size / 2; let l_out = p.l_out(); - let out_elems = p.c_out * l_out * p.b_size.unwrap_or(1); - let dst = vec![T::zero(); out_elems]; - // TODO: actually implement the ops. + let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1); + let mut dst = vec![T::zero(); dst_elems]; + // The output shape is [b_size, c_out, l_out] + for b_idx in 0..p.b_size.unwrap_or(1) { + let inp_idx = b_idx * inp_stride0; + let dst_idx = b_idx * p.c_out * l_out; + for dst_c_idx in 0..p.c_out { + let dst_idx = dst_idx + dst_c_idx * l_out; + for dst_l in 0..l_out { + let dst_idx = dst_idx + dst_l; + let mut d = T::zero(); + for offset in 0..p.k_size { + // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] + if k_over_2 <= dst_l + offset && dst_l + offset < k_over_2 + p.l_in { + let src_l = dst_l + offset - k_over_2; + for src_c_idx in 0..p.c_in { + let inp_idx = + inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; + let k_idx = dst_c_idx * k_stride[0] + + src_c_idx * k_stride[1] + + offset * k_stride[2]; + d += inp[inp_idx] * k[k_idx] + } + } + } + dst[dst_idx] = d + } + } + } Ok(dst) } }