mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Very inefficient conv1d implementation.
This commit is contained in:
@ -206,18 +206,55 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
||||
|
||||
impl<'a> Map2 for Conv1D<'a> {
|
||||
const OP: &'static str = "conv1d";
|
||||
fn f<T: 'static + num_traits::Num + Copy>(
|
||||
fn f<T: 'static + num_traits::NumAssign + Copy>(
|
||||
&self,
|
||||
_inp: &[T],
|
||||
_inp_l: &Layout,
|
||||
_k: &[T],
|
||||
_k_l: &Layout,
|
||||
inp: &[T],
|
||||
inp_l: &Layout,
|
||||
k: &[T],
|
||||
k_l: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
// TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let k = &k[k_l.start_offset()..];
|
||||
let inp_stride = inp_l.stride();
|
||||
let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
|
||||
(inp_stride[0], &inp_stride[1..])
|
||||
} else {
|
||||
(0, inp_stride) // This value never gets used anyway
|
||||
};
|
||||
let k_stride = k_l.stride();
|
||||
let k_over_2 = p.k_size / 2;
|
||||
let l_out = p.l_out();
|
||||
let out_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
|
||||
let dst = vec![T::zero(); out_elems];
|
||||
// TODO: actually implement the ops.
|
||||
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
|
||||
let mut dst = vec![T::zero(); dst_elems];
|
||||
// The output shape is [b_size, c_out, l_out]
|
||||
for b_idx in 0..p.b_size.unwrap_or(1) {
|
||||
let inp_idx = b_idx * inp_stride0;
|
||||
let dst_idx = b_idx * p.c_out * l_out;
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
let dst_idx = dst_idx + dst_c_idx * l_out;
|
||||
for dst_l in 0..l_out {
|
||||
let dst_idx = dst_idx + dst_l;
|
||||
let mut d = T::zero();
|
||||
for offset in 0..p.k_size {
|
||||
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
|
||||
if k_over_2 <= dst_l + offset && dst_l + offset < k_over_2 + p.l_in {
|
||||
let src_l = dst_l + offset - k_over_2;
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx =
|
||||
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
|
||||
let k_idx = dst_c_idx * k_stride[0]
|
||||
+ src_c_idx * k_stride[1]
|
||||
+ offset * k_stride[2];
|
||||
d += inp[inp_idx] * k[k_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[dst_idx] = d
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user