mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Extract the strides in the conv ops. (#370)
This commit is contained in:
@ -992,19 +992,14 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let k = &k[k_l.start_offset()..];
|
||||
let inp_stride = inp_l.stride();
|
||||
let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
|
||||
(inp_stride[0], &inp_stride[1..])
|
||||
} else {
|
||||
(0, inp_stride) // This value never gets used anyway
|
||||
};
|
||||
let k_stride = k_l.stride();
|
||||
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
|
||||
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
|
||||
let l_out = p.l_out();
|
||||
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
|
||||
let mut dst = vec![T::zero(); dst_elems];
|
||||
// The output shape is [b_size, c_out, l_out]
|
||||
for b_idx in 0..p.b_size.unwrap_or(1) {
|
||||
let inp_idx = b_idx * inp_stride0;
|
||||
let inp_idx = b_idx * inp_s0;
|
||||
let dst_idx = b_idx * p.c_out * l_out;
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
let dst_idx = dst_idx + dst_c_idx * l_out;
|
||||
@ -1016,11 +1011,8 @@ impl<'a> Map2 for Conv1D<'a> {
|
||||
.saturating_sub(p.padding)
|
||||
.min(p.l_in - 1);
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx =
|
||||
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
|
||||
let k_idx = dst_c_idx * k_stride[0]
|
||||
+ src_c_idx * k_stride[1]
|
||||
+ offset * k_stride[2];
|
||||
let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
|
||||
let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
|
||||
d += inp[inp_idx] * k[k_idx]
|
||||
}
|
||||
}
|
||||
@ -1045,14 +1037,14 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
let inp = &inp[inp_l.start_offset()..];
|
||||
let inp_stride = inp_l.stride();
|
||||
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
|
||||
let k = &k[k_l.start_offset()..];
|
||||
let k_stride = k_l.stride();
|
||||
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
|
||||
let (out_h, out_w) = (p.out_h(), p.out_w());
|
||||
|
||||
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
|
||||
for b_idx in 0..p.b_size {
|
||||
let inp_idx = b_idx * inp_stride[0];
|
||||
let inp_idx = b_idx * inp_s0;
|
||||
let dst_idx = b_idx * p.c_out * out_h * out_w;
|
||||
for dst_c_idx in 0..p.c_out {
|
||||
let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
|
||||
@ -1071,13 +1063,13 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
.min(p.i_w - 1);
|
||||
for src_c_idx in 0..p.c_in {
|
||||
let inp_idx = inp_idx
|
||||
+ src_c_idx * inp_stride[1]
|
||||
+ src_h * inp_stride[2]
|
||||
+ src_w * inp_stride[3];
|
||||
let k_idx = dst_c_idx * k_stride[0]
|
||||
+ src_c_idx * k_stride[1]
|
||||
+ offset_h * k_stride[2]
|
||||
+ offset_w * k_stride[3];
|
||||
+ src_c_idx * inp_s1
|
||||
+ src_h * inp_s2
|
||||
+ src_w * inp_s3;
|
||||
let k_idx = dst_c_idx * k_s0
|
||||
+ src_c_idx * k_s1
|
||||
+ offset_h * k_s2
|
||||
+ offset_w * k_s3;
|
||||
d += inp[inp_idx] * k[k_idx]
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user