mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Parallelise the CPU kernels for the conv ops. (#401)
* Parallelise the conv2d op. * Tighter control on threading. * Also parallelise conv1d. * Add some safety comment.
This commit is contained in:
@ -1032,7 +1032,7 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
let l_out = p.l_out();
|
||||
let dst_elems = p.c_out * l_out * p.b_size;
|
||||
// The output shape is [b_size, c_out, l_out]
|
||||
let mut dst = vec![T::zero(); dst_elems];
|
||||
let dst = vec![T::zero(); dst_elems];
|
||||
|
||||
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
|
||||
@ -1045,8 +1045,10 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
|
||||
for offset in 0..p.k_size {
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
|
||||
let dst_idx = dst_c_idx * l_out;
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
|
||||
@ -1063,10 +1065,17 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
assert!(k_cont.len() >= p.c_in);
|
||||
let mut d = T::zero();
|
||||
unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
|
||||
dst[dst_idx] += d
|
||||
let dst_p = dst.as_ptr();
|
||||
// Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
|
||||
// the different tasks so no two threads can try to write at the same
|
||||
// location.
|
||||
unsafe {
|
||||
let ptr = dst_p.add(dst_idx) as *mut T;
|
||||
*ptr += d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
@ -1085,7 +1094,7 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
let (out_h, out_w) = (p.out_h(), p.out_w());
|
||||
|
||||
// Output shape: [b_size, c_out, out_h, out_w].
|
||||
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
|
||||
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
|
||||
@ -1105,9 +1114,11 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
let num_threads = crate::utils::get_num_threads();
|
||||
|
||||
for offset_h in 0..p.k_h {
|
||||
for offset_w in 0..p.k_w {
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
|
||||
let dst_idx = dst_c_idx * out_w * out_h;
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| {
|
||||
@ -1137,11 +1148,18 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
unsafe {
|
||||
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
|
||||
}
|
||||
dst[dst_idx] += d
|
||||
let dst_p = dst.as_ptr();
|
||||
// Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
|
||||
// the different tasks so no two threads can try to write at the same
|
||||
// location.
|
||||
unsafe {
|
||||
let ptr = dst_p.add(dst_idx) as *mut T;
|
||||
*ptr += d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -26,3 +26,37 @@ impl VecDot for half::bf16 {}
|
||||
impl VecDot for half::f16 {}
|
||||
impl VecDot for u8 {}
|
||||
impl VecDot for u32 {}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
|
||||
if n_threads == 1 {
|
||||
func(0)
|
||||
} else {
|
||||
rayon::scope(|s| {
|
||||
for thread_idx in 0..n_threads {
|
||||
let func = &func;
|
||||
s.spawn(move |_| func(thread_idx));
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
|
||||
if n_threads == 1 {
|
||||
for i in lo..up {
|
||||
func(i)
|
||||
}
|
||||
} else {
|
||||
rayon::scope(|s| {
|
||||
for thread_idx in 0..n_threads {
|
||||
let func = &func;
|
||||
s.spawn(move |_| {
|
||||
for i in (thread_idx..up).step_by(n_threads) {
|
||||
func(i)
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -60,6 +60,8 @@ pub trait WithDType:
|
||||
+ std::cmp::PartialOrd
|
||||
+ std::fmt::Display
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ crate::cpu_kernels::VecDot
|
||||
{
|
||||
const DTYPE: DType;
|
||||
|
Reference in New Issue
Block a user