diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index c997d767..0c4e4597 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1039,12 +1039,59 @@ impl<'a> Map2 for Conv2D<'a> { const OP: &'static str = "conv2d"; fn f( &self, - _inp: &[T], - _inp_l: &Layout, - _k: &[T], - _k_l: &Layout, + inp: &[T], + inp_l: &Layout, + k: &[T], + k_l: &Layout, ) -> Result> { - todo!() + let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let inp_stride = inp_l.stride(); + let k = &k[k_l.start_offset()..]; + let k_stride = k_l.stride(); + let (out_h, out_w) = (p.out_h(), p.out_w()); + + let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; + for b_idx in 0..p.b_size { + let inp_idx = b_idx * inp_stride[0]; + let dst_idx = b_idx * p.c_out * out_h * out_w; + for dst_c_idx in 0..p.c_out { + let dst_idx = dst_idx + dst_c_idx * out_h * out_w; + for dst_h in 0..out_h { + let dst_idx = dst_idx + dst_h * out_w; + for dst_w in 0..out_h { + let dst_idx = dst_idx + dst_w; + let mut d = T::zero(); + for offset_h in 0..p.k_h { + let src_h_plus = p.stride * dst_h + offset_h; + if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h { + let src_h = src_h_plus - p.k_h / 2; + for offset_w in 0..p.k_w { + let src_w_plus = p.stride * dst_w + offset_w; + // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] + if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w { + let src_w = src_w_plus - p.k_w / 2; + for src_c_idx in 0..p.c_in { + let inp_idx = inp_idx + + src_c_idx * inp_stride[1] + + src_h * inp_stride[2] + + src_w * inp_stride[3]; + let k_idx = dst_c_idx * k_stride[0] + + src_c_idx * k_stride[1] + + offset_h * k_stride[2] + + offset_w * k_stride[3]; + d += inp[inp_idx] * k[k_idx] + } + } + } + } + } + dst[dst_idx] = d + } + } + } + } + Ok(dst) } }