Parallelise the CPU kernels for the conv ops. (#401)

* Parallelise the conv2d op.

* Tighter control on threading.

* Also parallelise conv1d.

* Add safety comments.
Author: Laurent Mazare
Date: 2023-08-11 06:51:58 +02:00
Committed by: GitHub
Parent: a325c1aa50
Commit: e29c7809ec
5 changed files with 64 additions and 8 deletions


@@ -1032,7 +1032,7 @@ impl<'a> Map2 for Conv1D<'a> {
         let l_out = p.l_out();
         let dst_elems = p.c_out * l_out * p.b_size;
         // The output shape is [b_size, c_out, l_out]
-        let mut dst = vec![T::zero(); dst_elems];
+        let dst = vec![T::zero(); dst_elems];
         // TODO: Avoid making this copy if `inp` already has the appropriate layout.
         let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
@@ -1045,8 +1045,10 @@ impl<'a> Map2 for Conv1D<'a> {
             }
         }
+        let num_threads = crate::utils::get_num_threads();
         for offset in 0..p.k_size {
-            for dst_c_idx in 0..p.c_out {
+            crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
                 let dst_idx = dst_c_idx * l_out;
                 let k_cont = (0..p.c_in)
                     .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@@ -1063,10 +1065,17 @@ impl<'a> Map2 for Conv1D<'a> {
                         assert!(k_cont.len() >= p.c_in);
                         let mut d = T::zero();
                         unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
-                        dst[dst_idx] += d
+                        let dst_p = dst.as_ptr();
+                        // Safety: dst_idx values are unique per dst_c_idx, which is used to
+                        // parallelise the different tasks, so no two threads can try to write
+                        // to the same location.
+                        unsafe {
+                            let ptr = dst_p.add(dst_idx) as *mut T;
+                            *ptr += d
+                        }
                     }
                 }
-            }
+            })
         }
         Ok(dst)
     }
@@ -1085,7 +1094,7 @@ impl<'a> Map2 for Conv2D<'a> {
         let (out_h, out_w) = (p.out_h(), p.out_w());
         // Output shape: [b_size, c_out, out_h, out_w].
-        let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
         // TODO: Avoid making this copy if `inp` already has the appropriate layout.
         let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
@@ -1105,9 +1114,11 @@ impl<'a> Map2 for Conv2D<'a> {
             }
         }
+        let num_threads = crate::utils::get_num_threads();
         for offset_h in 0..p.k_h {
             for offset_w in 0..p.k_w {
-                for dst_c_idx in 0..p.c_out {
+                crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
                     let dst_idx = dst_c_idx * out_w * out_h;
                     let k_cont = (0..p.c_in)
                         .map(|c_in_idx| {
@@ -1137,11 +1148,18 @@ impl<'a> Map2 for Conv2D<'a> {
                                 unsafe {
                                     T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                                 }
-                                dst[dst_idx] += d
+                                let dst_p = dst.as_ptr();
+                                // Safety: dst_idx values are unique per dst_c_idx, which is used to
+                                // parallelise the different tasks, so no two threads can try to write
+                                // to the same location.
+                                unsafe {
+                                    let ptr = dst_p.add(dst_idx) as *mut T;
+                                    *ptr += d
+                                }
                             }
                         }
                     }
-                }
+                });
             }
         }
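
The safety comments above rest on each output element being owned by exactly one dst_c_idx: for conv1d the write index is dst_idx = b_idx * c_out * l_out + dst_c_idx * l_out + dst_l with dst_l < l_out and dst_c_idx < c_out, so a given dst_idx determines dst_c_idx, and par_range hands each dst_c_idx to exactly one worker (the conv2d index works the same way with out_h * out_w in place of l_out). A small, hypothetical property check of that argument, not part of this commit:

// Hypothetical check (not in the commit): every conv1d output index is
// produced by a single dst_c_idx, so workers that own disjoint channel
// indices never write to the same location.
fn disjoint_writes_per_channel(b_size: usize, c_out: usize, l_out: usize) {
    use std::collections::HashMap;
    let mut owner: HashMap<usize, usize> = HashMap::new();
    for b_idx in 0..b_size {
        for dst_c_idx in 0..c_out {
            for dst_l in 0..l_out {
                let dst_idx = b_idx * c_out * l_out + dst_c_idx * l_out + dst_l;
                // The first channel to claim dst_idx must be the only one.
                assert_eq!(*owner.entry(dst_idx).or_insert(dst_c_idx), dst_c_idx);
            }
        }
    }
}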


@@ -26,3 +26,37 @@ impl VecDot for half::bf16 {}
 impl VecDot for half::f16 {}
 impl VecDot for u8 {}
 impl VecDot for u32 {}
+
+#[inline(always)]
+pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
+    if n_threads == 1 {
+        func(0)
+    } else {
+        rayon::scope(|s| {
+            for thread_idx in 0..n_threads {
+                let func = &func;
+                s.spawn(move |_| func(thread_idx));
+            }
+        })
+    }
+}
+
+#[inline(always)]
+pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
+    if n_threads == 1 {
+        for i in lo..up {
+            func(i)
+        }
+    } else {
+        rayon::scope(|s| {
+            for thread_idx in 0..n_threads {
+                let func = &func;
+                s.spawn(move |_| {
+                    for i in (thread_idx..up).step_by(n_threads) {
+                        func(i)
+                    }
+                });
+            }
+        })
+    }
+}
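
For reference, a hypothetical usage sketch of the par_range helper added above (assumed to be in scope; row_sums and its arguments are made up for illustration). With lo = 0, each worker handles the strided index set thread_idx, thread_idx + n_threads, ..., so the caller only has to guarantee that distinct indices never touch the same output location, which is exactly the pattern the conv kernels use:

// Hypothetical caller mirroring the conv kernels: one row per index, so the
// raw-pointer writes of different workers can never alias.
fn row_sums(data: &[f32], rows: usize, cols: usize, n_threads: usize) -> Vec<f32> {
    let out = vec![0f32; rows];
    par_range(0, rows, n_threads, |r| {
        let s: f32 = data[r * cols..(r + 1) * cols].iter().sum();
        let out_p = out.as_ptr();
        // Safety: each row index r is handled by exactly one thread.
        unsafe { *(out_p.add(r) as *mut f32) = s }
    });
    out
}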


@@ -60,6 +60,8 @@ pub trait WithDType:
     + std::cmp::PartialOrd
     + std::fmt::Display
     + 'static
+    + Send
+    + Sync
     + crate::cpu_kernels::VecDot
 {
     const DTYPE: DType;
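
These bounds are what let the parallel kernels above type-check: rayon's scoped spawns require their closures to be Send, the conv closures capture &[T] and &Vec<T>, and a shared reference is Send only when T is Sync. A minimal, hypothetical illustration (spawn_over is not part of the commit):

// Hypothetical sketch: moving a `&[T]` into a rayon task only compiles because
// `T: Sync` (the commit also adds `Send` so owned values can cross threads).
fn spawn_over<T: Send + Sync>(data: &[T]) {
    rayon::scope(|s| {
        s.spawn(move |_| {
            let _n = data.len();
        });
    });
}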