Parallelise the CPU kernels for the conv ops. (#401)

* Parallelise the conv2d op.

* Tighter control on threading.

* Also parallelise conv1d.

* Add some safety comment.
This commit is contained in:
Laurent Mazare
2023-08-11 06:51:58 +02:00
committed by GitHub
parent a325c1aa50
commit e29c7809ec
5 changed files with 64 additions and 8 deletions

View File

@ -1032,7 +1032,7 @@ impl<'a> Map2 for Conv1D<'a> {
let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size;
// The output shape is [b_size, c_out, l_out]
let mut dst = vec![T::zero(); dst_elems];
let dst = vec![T::zero(); dst_elems];
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
@ -1045,8 +1045,10 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
let num_threads = crate::utils::get_num_threads();
for offset in 0..p.k_size {
for dst_c_idx in 0..p.c_out {
crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@ -1063,10 +1065,17 @@ impl<'a> Map2 for Conv1D<'a> {
assert!(k_cont.len() >= p.c_in);
let mut d = T::zero();
unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
dst[dst_idx] += d
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
// the different tasks so no two threads can try to write at the same
// location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
})
}
Ok(dst)
}
@ -1085,7 +1094,7 @@ impl<'a> Map2 for Conv2D<'a> {
let (out_h, out_w) = (p.out_h(), p.out_w());
// Output shape: [b_size, c_out, out_h, out_w].
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
@ -1105,9 +1114,11 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
let num_threads = crate::utils::get_num_threads();
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
for dst_c_idx in 0..p.c_out {
crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
@ -1137,11 +1148,18 @@ impl<'a> Map2 for Conv2D<'a> {
unsafe {
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
}
dst[dst_idx] += d
let dst_p = dst.as_ptr();
// Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
// the different tasks so no two threads can try to write at the same
// location.
unsafe {
let ptr = dst_p.add(dst_idx) as *mut T;
*ptr += d
}
}
}
}
}
});
}
}