mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Naive implementation for conv2d. (#341)
This commit is contained in:
@ -1039,12 +1039,59 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
const OP: &'static str = "conv2d";
|
||||
fn f<T: 'static + num_traits::NumAssign + Copy>(
|
||||
&self,
|
||||
_inp: &[T],
|
||||
_inp_l: &Layout,
|
||||
_k: &[T],
|
||||
_k_l: &Layout,
|
||||
inp: &[T],
|
||||
inp_l: &Layout,
|
||||
k: &[T],
|
||||
k_l: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
todo!()
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let inp_stride = inp_l.stride();
|
||||
let k = &k[k_l.start_offset()..];
|
||||
let k_stride = k_l.stride();
|
||||
let (out_h, out_w) = (p.out_h(), p.out_w());
|
||||
|
||||
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
for b_idx in 0..p.b_size {
|
||||
let inp_idx = b_idx * inp_stride[0];
|
||||
let dst_idx = b_idx * p.c_out * out_h * out_w;
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
|
||||
for dst_h in 0..out_h {
|
||||
let dst_idx = dst_idx + dst_h * out_w;
|
||||
for dst_w in 0..out_h {
|
||||
let dst_idx = dst_idx + dst_w;
|
||||
let mut d = T::zero();
|
||||
for offset_h in 0..p.k_h {
|
||||
let src_h_plus = p.stride * dst_h + offset_h;
|
||||
if p.k_h / 2 <= src_h_plus && src_h_plus < p.k_h / 2 + p.i_h {
|
||||
let src_h = src_h_plus - p.k_h / 2;
|
||||
for offset_w in 0..p.k_w {
|
||||
let src_w_plus = p.stride * dst_w + offset_w;
|
||||
// inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset]
|
||||
if p.k_w / 2 <= src_w_plus && src_w_plus < p.k_w / 2 + p.i_w {
|
||||
let src_w = src_w_plus - p.k_w / 2;
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx = inp_idx
|
||||
+ src_c_idx * inp_stride[1]
|
||||
+ src_h * inp_stride[2]
|
||||
+ src_w * inp_stride[3];
|
||||
let k_idx = dst_c_idx * k_stride[0]
|
||||
+ src_c_idx * k_stride[1]
|
||||
+ offset_h * k_stride[2]
|
||||
+ offset_w * k_stride[3];
|
||||
d += inp[inp_idx] * k[k_idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[dst_idx] = d
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user