Conv1d optimize (#392)

* Reorder the conv1d loops in the cpu backend.

* Optimize the 1d convolution.

* Conv1D optimize.

* Fix some clippy lints.
This commit is contained in:
Laurent Mazare
2023-08-10 16:23:52 +02:00
committed by GitHub
parent 0b0fa56978
commit c8039579a5
6 changed files with 62 additions and 22 deletions

View File

@ -1023,14 +1023,7 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
const OP: &'static str = "conv1d";
fn f<T: 'static + num_traits::NumAssign + Copy>(
&self,
inp: &[T],
inp_l: &Layout,
k: &[T],
k_l: &Layout,
) -> Result<Vec<T>> {
// TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
@ -1040,25 +1033,35 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_elems = p.c_out * l_out * p.b_size;
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
for b_idx in 0..p.b_size {
let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * l_out;
for src_l in 0..p.l_in {
for src_c_idx in 0..p.c_in {
let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
}
}
}
for offset in 0..p.k_size {
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
let mut d = T::zero();
for offset in 0..p.k_size {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
.collect::<Vec<_>>();
for b_idx in 0..p.b_size {
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
let src_l = (p.stride * dst_l + offset)
.saturating_sub(p.padding)
.min(p.l_in - 1);
for src_c_idx in 0..p.c_in {
let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
d += inp[inp_idx] * k[k_idx]
}
let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
assert!(inp_cont.len() >= p.c_in);
assert!(k_cont.len() >= p.c_in);
let mut d = T::zero();
unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
dst[dst_idx] += d
}
dst[dst_idx] = d
}
}
}