mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Optimize the cpu conv2d kernel (#396)
* Conv2d simd optimization. * Fix the contiguous copying. * Small tweak.
This commit is contained in:
@ -1031,8 +1031,10 @@ impl<'a> Map2 for Conv1D<'a> {
|
|||||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||||
let l_out = p.l_out();
|
let l_out = p.l_out();
|
||||||
let dst_elems = p.c_out * l_out * p.b_size;
|
let dst_elems = p.c_out * l_out * p.b_size;
|
||||||
let mut dst = vec![T::zero(); dst_elems];
|
|
||||||
// The output shape is [b_size, c_out, l_out]
|
// The output shape is [b_size, c_out, l_out]
|
||||||
|
let mut dst = vec![T::zero(); dst_elems];
|
||||||
|
|
||||||
|
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
|
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
|
||||||
for b_idx in 0..p.b_size {
|
for b_idx in 0..p.b_size {
|
||||||
for src_l in 0..p.l_in {
|
for src_l in 0..p.l_in {
|
||||||
@ -1042,6 +1044,7 @@ impl<'a> Map2 for Conv1D<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for offset in 0..p.k_size {
|
for offset in 0..p.k_size {
|
||||||
for dst_c_idx in 0..p.c_out {
|
for dst_c_idx in 0..p.c_out {
|
||||||
let dst_idx = dst_c_idx * l_out;
|
let dst_idx = dst_c_idx * l_out;
|
||||||
@ -1073,13 +1076,7 @@ struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
|||||||
|
|
||||||
impl<'a> Map2 for Conv2D<'a> {
|
impl<'a> Map2 for Conv2D<'a> {
|
||||||
const OP: &'static str = "conv2d";
|
const OP: &'static str = "conv2d";
|
||||||
fn f<T: 'static + num_traits::NumAssign + Copy + std::fmt::Display>(
|
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||||
&self,
|
|
||||||
inp: &[T],
|
|
||||||
inp_l: &Layout,
|
|
||||||
k: &[T],
|
|
||||||
k_l: &Layout,
|
|
||||||
) -> Result<Vec<T>> {
|
|
||||||
let p = self.0;
|
let p = self.0;
|
||||||
let inp = &inp[inp_l.start_offset()..];
|
let inp = &inp[inp_l.start_offset()..];
|
||||||
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
|
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
|
||||||
@ -1087,43 +1084,67 @@ impl<'a> Map2 for Conv2D<'a> {
|
|||||||
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
|
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
|
||||||
let (out_h, out_w) = (p.out_h(), p.out_w());
|
let (out_h, out_w) = (p.out_h(), p.out_w());
|
||||||
|
|
||||||
|
// Output shape: [b_size, c_out, out_h, out_w].
|
||||||
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||||
|
|
||||||
|
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
|
||||||
|
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
|
||||||
|
let cont_s0 = p.i_h * p.i_w * p.c_in;
|
||||||
|
let cont_s1 = p.i_w * p.c_in;
|
||||||
|
let cont_s2 = p.c_in;
|
||||||
for b_idx in 0..p.b_size {
|
for b_idx in 0..p.b_size {
|
||||||
let inp_idx = b_idx * inp_s0;
|
for h_idx in 0..p.i_h {
|
||||||
let dst_idx = b_idx * p.c_out * out_h * out_w;
|
for w_idx in 0..p.i_w {
|
||||||
|
for c_idx in 0..p.c_in {
|
||||||
|
let src_idx =
|
||||||
|
b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
|
||||||
|
let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
|
||||||
|
inp_cont[dst_idx] = inp[src_idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for offset_h in 0..p.k_h {
|
||||||
|
for offset_w in 0..p.k_w {
|
||||||
for dst_c_idx in 0..p.c_out {
|
for dst_c_idx in 0..p.c_out {
|
||||||
let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
|
let dst_idx = dst_c_idx * out_w * out_h;
|
||||||
|
let k_cont = (0..p.c_in)
|
||||||
|
.map(|c_in_idx| {
|
||||||
|
k[dst_c_idx * k_s0
|
||||||
|
+ c_in_idx * k_s1
|
||||||
|
+ offset_h * k_s2
|
||||||
|
+ offset_w * k_s3]
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
for b_idx in 0..p.b_size {
|
||||||
|
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
|
||||||
for dst_h in 0..out_h {
|
for dst_h in 0..out_h {
|
||||||
let dst_idx = dst_idx + dst_h * out_w;
|
let dst_idx = dst_idx + dst_h * out_w;
|
||||||
for dst_w in 0..out_w {
|
|
||||||
let dst_idx = dst_idx + dst_w;
|
|
||||||
let mut d = T::zero();
|
|
||||||
for offset_h in 0..p.k_h {
|
|
||||||
let src_h = (p.stride * dst_h + offset_h)
|
let src_h = (p.stride * dst_h + offset_h)
|
||||||
.saturating_sub(p.padding)
|
.saturating_sub(p.padding)
|
||||||
.min(p.i_h - 1);
|
.min(p.i_h - 1);
|
||||||
for offset_w in 0..p.k_w {
|
for dst_w in 0..out_w {
|
||||||
|
let dst_idx = dst_idx + dst_w;
|
||||||
let src_w = (p.stride * dst_w + offset_w)
|
let src_w = (p.stride * dst_w + offset_w)
|
||||||
.saturating_sub(p.padding)
|
.saturating_sub(p.padding)
|
||||||
.min(p.i_w - 1);
|
.min(p.i_w - 1);
|
||||||
for src_c_idx in 0..p.c_in {
|
let inp_cont = &inp_cont
|
||||||
let inp_idx = inp_idx
|
[b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
|
||||||
+ src_c_idx * inp_s1
|
assert!(inp_cont.len() >= p.c_in);
|
||||||
+ src_h * inp_s2
|
assert!(k_cont.len() >= p.c_in);
|
||||||
+ src_w * inp_s3;
|
let mut d = T::zero();
|
||||||
let k_idx = dst_c_idx * k_s0
|
unsafe {
|
||||||
+ src_c_idx * k_s1
|
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
|
||||||
+ offset_h * k_s2
|
|
||||||
+ offset_w * k_s3;
|
|
||||||
d += inp[inp_idx] * k[k_idx]
|
|
||||||
}
|
}
|
||||||
}
|
dst[dst_idx] += d
|
||||||
}
|
|
||||||
dst[dst_idx] = d
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user