diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 250e2721..fa24c434 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1031,8 +1031,10 @@ impl<'a> Map2 for Conv1D<'a> { let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; let l_out = p.l_out(); let dst_elems = p.c_out * l_out * p.b_size; - let mut dst = vec![T::zero(); dst_elems]; // The output shape is [b_size, c_out, l_out] + let mut dst = vec![T::zero(); dst_elems]; + + // TODO: Avoid making this copy if `inp` already has the appropriate layout. let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in]; for b_idx in 0..p.b_size { for src_l in 0..p.l_in { @@ -1042,6 +1044,7 @@ impl<'a> Map2 for Conv1D<'a> { } } } + for offset in 0..p.k_size { for dst_c_idx in 0..p.c_out { let dst_idx = dst_c_idx * l_out; @@ -1073,13 +1076,7 @@ struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); impl<'a> Map2 for Conv2D<'a> { const OP: &'static str = "conv2d"; - fn f( - &self, - inp: &[T], - inp_l: &Layout, - k: &[T], - k_l: &Layout, - ) -> Result> { + fn f(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result> { let p = self.0; let inp = &inp[inp_l.start_offset()..]; let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?; @@ -1087,43 +1084,67 @@ impl<'a> Map2 for Conv2D<'a> { let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?; let (out_h, out_w) = (p.out_h(), p.out_w()); + // Output shape: [b_size, c_out, out_h, out_w]. let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; + + // TODO: Avoid making this copy if `inp` already has the appropriate layout. + let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w]; + let cont_s0 = p.i_h * p.i_w * p.c_in; + let cont_s1 = p.i_w * p.c_in; + let cont_s2 = p.c_in; for b_idx in 0..p.b_size { - let inp_idx = b_idx * inp_s0; - let dst_idx = b_idx * p.c_out * out_h * out_w; - for dst_c_idx in 0..p.c_out { - let dst_idx = dst_idx + dst_c_idx * out_h * out_w; - for dst_h in 0..out_h { - let dst_idx = dst_idx + dst_h * out_w; - for dst_w in 0..out_w { - let dst_idx = dst_idx + dst_w; - let mut d = T::zero(); - for offset_h in 0..p.k_h { - let src_h = (p.stride * dst_h + offset_h) - .saturating_sub(p.padding) - .min(p.i_h - 1); - for offset_w in 0..p.k_w { - let src_w = (p.stride * dst_w + offset_w) - .saturating_sub(p.padding) - .min(p.i_w - 1); - for src_c_idx in 0..p.c_in { - let inp_idx = inp_idx - + src_c_idx * inp_s1 - + src_h * inp_s2 - + src_w * inp_s3; - let k_idx = dst_c_idx * k_s0 - + src_c_idx * k_s1 - + offset_h * k_s2 - + offset_w * k_s3; - d += inp[inp_idx] * k[k_idx] - } - } - } - dst[dst_idx] = d + for h_idx in 0..p.i_h { + for w_idx in 0..p.i_w { + for c_idx in 0..p.c_in { + let src_idx = + b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3; + let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx; + inp_cont[dst_idx] = inp[src_idx] } } } } + + for offset_h in 0..p.k_h { + for offset_w in 0..p.k_w { + for dst_c_idx in 0..p.c_out { + let dst_idx = dst_c_idx * out_w * out_h; + let k_cont = (0..p.c_in) + .map(|c_in_idx| { + k[dst_c_idx * k_s0 + + c_in_idx * k_s1 + + offset_h * k_s2 + + offset_w * k_s3] + }) + .collect::>(); + for b_idx in 0..p.b_size { + let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w; + for dst_h in 0..out_h { + let dst_idx = dst_idx + dst_h * out_w; + let src_h = (p.stride * dst_h + offset_h) + .saturating_sub(p.padding) + .min(p.i_h - 1); + for dst_w in 0..out_w { + let dst_idx = dst_idx + dst_w; + let src_w = (p.stride * dst_w + offset_w) + .saturating_sub(p.padding) + .min(p.i_w - 1); + let inp_cont = &inp_cont + [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..]; + assert!(inp_cont.len() >= p.c_in); + assert!(k_cont.len() >= p.c_in); + let mut d = T::zero(); + unsafe { + T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) + } + dst[dst_idx] += d + } + } + } + } + } + } + Ok(dst) } }