diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 07fc78fc..f52d53b1 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1193,41 +1193,78 @@ impl<'a> Map2 for ConvTranspose2D<'a> { let (out_h, out_w) = (p.out_h(), p.out_w()); // Output shape: [b_size, c_out, out_h, out_w]. - let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; + let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; let dst_s0 = p.c_out * out_h * out_w; let dst_s1 = out_h * out_w; let dst_s2 = out_w; let dst_s3 = 1; + + // TODO: Avoid making this copy if `inp` already has the appropriate layout. + let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w]; + let cont_s0 = p.i_h * p.i_w * p.c_in; + let cont_s1 = p.i_w * p.c_in; + let cont_s2 = p.c_in; for b_idx in 0..p.b_size { - for inp_y in 0..p.i_h { - for inp_x in 0..p.i_w { - let out_x = (inp_x * p.stride) as i32 - p.padding as i32; - let out_y = (inp_y * p.stride) as i32 - p.padding as i32; - for k_y in 0..p.k_h as i32 { - for k_x in 0..p.k_w as i32 { - let k_index = k_y as usize * k_s2 + k_x as usize * k_s3; - let out_y = out_y + k_y; - let out_x = out_x + k_x; - if out_x < 0 || out_y < 0 { - continue; - } - let out_x = out_x as usize; - let out_y = out_y as usize; - if out_x < out_w && out_y < out_h { - let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3; - let dst_index = b_idx * dst_s0 + out_y * dst_s2 + out_x * dst_s3; - for c_out in 0..p.c_out { - for c_in in 0..p.c_in { - let k_index = k_index + c_out * k_s1 + c_in * k_s0; - let dst_index = dst_index + c_out * dst_s1; - let inp_index = inp_index + c_in * inp_s1; - dst[dst_index] += k[k_index] * inp[inp_index] + for h_idx in 0..p.i_h { + for w_idx in 0..p.i_w { + for c_idx in 0..p.c_in { + let src_idx = + b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3; + let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx; + inp_cont[dst_idx] = inp[src_idx] + } + } + } + } + let num_threads = crate::utils::get_num_threads(); + + for k_y in 0..p.k_h { + for k_x in 0..p.k_w { + crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { + let k_cont = (0..p.c_in) + .map(|c_in_idx| { + k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3] + }) + .collect::>(); + for b_idx in 0..p.b_size { + for inp_y in 0..p.i_h { + for inp_x in 0..p.i_w { + let out_x = inp_x * p.stride + k_x; + let out_y = inp_y * p.stride + k_y; + if out_x < p.padding || out_y < p.padding { + continue; + } + let out_x = out_x - p.padding; + let out_y = out_y - p.padding; + if out_x < out_w && out_y < out_h { + let inp_cont = &inp_cont + [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..]; + let dst_idx = b_idx * dst_s0 + + out_y * dst_s2 + + out_x * dst_s3 + + dst_c_idx * dst_s1; + let mut d = T::zero(); + unsafe { + T::vec_dot( + inp_cont.as_ptr(), + k_cont.as_ptr(), + &mut d, + p.c_in, + ) + } + let dst_p = dst.as_ptr(); + // Safety: dst_idx are uniques per dst_c_idx which is used to + // parallelise the different tasks so no two threads can try to + // write at the same location. + unsafe { + let ptr = dst_p.add(dst_idx) as *mut T; + *ptr += d } } } } } - } + }) } } Ok(dst)