mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Optimize the conv2d transpose cpu kernel. (#644)
* Optimize the conv2d transpose cpu kernel.
* Use multiple cores.
This commit is contained in:
@ -1193,41 +1193,78 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
let (out_h, out_w) = (p.out_h(), p.out_w());
|
||||
|
||||
// Output shape: [b_size, c_out, out_h, out_w].
|
||||
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
let dst_s0 = p.c_out * out_h * out_w;
|
||||
let dst_s1 = out_h * out_w;
|
||||
let dst_s2 = out_w;
|
||||
let dst_s3 = 1;
|
||||
|
||||
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
|
||||
let cont_s0 = p.i_h * p.i_w * p.c_in;
|
||||
let cont_s1 = p.i_w * p.c_in;
|
||||
let cont_s2 = p.c_in;
|
||||
for b_idx in 0..p.b_size {
|
||||
for inp_y in 0..p.i_h {
|
||||
for inp_x in 0..p.i_w {
|
||||
let out_x = (inp_x * p.stride) as i32 - p.padding as i32;
|
||||
let out_y = (inp_y * p.stride) as i32 - p.padding as i32;
|
||||
for k_y in 0..p.k_h as i32 {
|
||||
for k_x in 0..p.k_w as i32 {
|
||||
let k_index = k_y as usize * k_s2 + k_x as usize * k_s3;
|
||||
let out_y = out_y + k_y;
|
||||
let out_x = out_x + k_x;
|
||||
if out_x < 0 || out_y < 0 {
|
||||
continue;
|
||||
}
|
||||
let out_x = out_x as usize;
|
||||
let out_y = out_y as usize;
|
||||
if out_x < out_w && out_y < out_h {
|
||||
let inp_index = b_idx * inp_s0 + inp_y * inp_s2 + inp_x * inp_s3;
|
||||
let dst_index = b_idx * dst_s0 + out_y * dst_s2 + out_x * dst_s3;
|
||||
for c_out in 0..p.c_out {
|
||||
for c_in in 0..p.c_in {
|
||||
let k_index = k_index + c_out * k_s1 + c_in * k_s0;
|
||||
let dst_index = dst_index + c_out * dst_s1;
|
||||
let inp_index = inp_index + c_in * inp_s1;
|
||||
dst[dst_index] += k[k_index] * inp[inp_index]
|
||||
for h_idx in 0..p.i_h {
|
||||
for w_idx in 0..p.i_w {
|
||||
for c_idx in 0..p.c_in {
|
||||
let src_idx =
|
||||
b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
|
||||
let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
|
||||
inp_cont[dst_idx] = inp[src_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
|
||||
for k_y in 0..p.k_h {
|
||||
for k_x in 0..p.k_w {
|
||||
crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| {
|
||||
k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..p.b_size {
|
||||
for inp_y in 0..p.i_h {
|
||||
for inp_x in 0..p.i_w {
|
||||
let out_x = inp_x * p.stride + k_x;
|
||||
let out_y = inp_y * p.stride + k_y;
|
||||
if out_x < p.padding || out_y < p.padding {
|
||||
continue;
|
||||
}
|
||||
let out_x = out_x - p.padding;
|
||||
let out_y = out_y - p.padding;
|
||||
if out_x < out_w && out_y < out_h {
|
||||
let inp_cont = &inp_cont
|
||||
[b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
|
||||
let dst_idx = b_idx * dst_s0
|
||||
+ out_y * dst_s2
|
||||
+ out_x * dst_s3
|
||||
+ dst_c_idx * dst_s1;
|
||||
let mut d = T::zero();
|
||||
unsafe {
|
||||
T::vec_dot(
|
||||
inp_cont.as_ptr(),
|
||||
k_cont.as_ptr(),
|
||||
&mut d,
|
||||
p.c_in,
|
||||
)
|
||||
}
|
||||
let dst_p = dst.as_ptr();
|
||||
// Safety: each dst_idx is unique to its dst_c_idx, which is the index used to
|
||||
// parallelise the different tasks so no two threads can try to
|
||||
// write at the same location.
|
||||
unsafe {
|
||||
let ptr = dst_p.add(dst_idx) as *mut T;
|
||||
*ptr += d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
|
Reference in New Issue
Block a user