mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Conv1d optimize (#392)
* Reorder the conv1d loops in the cpu backend. * Optimize the 1d convolution. * Conv1D optimize. * Fix some clippy lints.
This commit is contained in:
@ -1023,14 +1023,7 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
||||
|
||||
impl<'a> Map2 for Conv1D<'a> {
|
||||
const OP: &'static str = "conv1d";
|
||||
fn f<T: 'static + num_traits::NumAssign + Copy>(
|
||||
&self,
|
||||
inp: &[T],
|
||||
inp_l: &Layout,
|
||||
k: &[T],
|
||||
k_l: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
// TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let k = &k[k_l.start_offset()..];
|
||||
@ -1040,25 +1033,35 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
let dst_elems = p.c_out * l_out * p.b_size;
|
||||
let mut dst = vec![T::zero(); dst_elems];
|
||||
// The output shape is [b_size, c_out, l_out]
|
||||
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
|
||||
for b_idx in 0..p.b_size {
|
||||
let inp_idx = b_idx * inp_s0;
|
||||
let dst_idx = b_idx * p.c_out * l_out;
|
||||
for src_l in 0..p.l_in {
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
|
||||
inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
for offset in 0..p.k_size {
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
let dst_idx = dst_idx + dst_c_idx * l_out;
|
||||
for dst_l in 0..l_out {
|
||||
let dst_idx = dst_idx + dst_l;
|
||||
let mut d = T::zero();
|
||||
for offset in 0..p.k_size {
|
||||
let dst_idx = dst_c_idx * l_out;
|
||||
let k_cont = (0..p.c_in)
|
||||
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
|
||||
.collect::<Vec<_>>();
|
||||
for b_idx in 0..p.b_size {
|
||||
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
|
||||
for dst_l in 0..l_out {
|
||||
let dst_idx = dst_idx + dst_l;
|
||||
let src_l = (p.stride * dst_l + offset)
|
||||
.saturating_sub(p.padding)
|
||||
.min(p.l_in - 1);
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
|
||||
let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
|
||||
d += inp[inp_idx] * k[k_idx]
|
||||
}
|
||||
let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
|
||||
assert!(inp_cont.len() >= p.c_in);
|
||||
assert!(k_cont.len() >= p.c_in);
|
||||
let mut d = T::zero();
|
||||
unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
|
||||
dst[dst_idx] += d
|
||||
}
|
||||
dst[dst_idx] = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user