Extract the strides in the conv ops. (#370)

This commit is contained in:
Laurent Mazare
2023-08-09 18:57:05 +02:00
committed by GitHub
parent cd225bd3b1
commit 1892bd139c
2 changed files with 32 additions and 35 deletions

View File

@ -992,19 +992,14 @@ impl<'a> Map2 for Conv1D<'a> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
let inp_stride = inp_l.stride();
let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
(inp_stride[0], &inp_stride[1..])
} else {
(0, inp_stride) // This value never gets used anyway
};
let k_stride = k_l.stride();
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
for b_idx in 0..p.b_size.unwrap_or(1) {
let inp_idx = b_idx * inp_stride0;
let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * l_out;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * l_out;
@ -1016,11 +1011,8 @@ impl<'a> Map2 for Conv1D<'a> {
.saturating_sub(p.padding)
.min(p.l_in - 1);
for src_c_idx in 0..p.c_in {
let inp_idx =
inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
let k_idx = dst_c_idx * k_stride[0]
+ src_c_idx * k_stride[1]
+ offset * k_stride[2];
let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
d += inp[inp_idx] * k[k_idx]
}
}
@ -1045,14 +1037,14 @@ impl<'a> Map2 for Conv2D<'a> {
) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let inp_stride = inp_l.stride();
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
let k_stride = k_l.stride();
let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
let (out_h, out_w) = (p.out_h(), p.out_w());
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
for b_idx in 0..p.b_size {
let inp_idx = b_idx * inp_stride[0];
let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * out_h * out_w;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
@ -1071,13 +1063,13 @@ impl<'a> Map2 for Conv2D<'a> {
.min(p.i_w - 1);
for src_c_idx in 0..p.c_in {
let inp_idx = inp_idx
+ src_c_idx * inp_stride[1]
+ src_h * inp_stride[2]
+ src_w * inp_stride[3];
let k_idx = dst_c_idx * k_stride[0]
+ src_c_idx * k_stride[1]
+ offset_h * k_stride[2]
+ offset_w * k_stride[3];
+ src_c_idx * inp_s1
+ src_h * inp_s2
+ src_w * inp_s3;
let k_idx = dst_c_idx * k_s0
+ src_c_idx * k_s1
+ offset_h * k_s2
+ offset_w * k_s3;
d += inp[inp_idx] * k[k_idx]
}
}

View File

@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape {
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
if self.0.len() != $cnt {
Err(Error::UnexpectedNumberOfDims {
expected: $cnt,
got: self.0.len(),
shape: self.clone(),
}
.bt())
} else {
Ok($dims(&self.0))
pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
if dims.len() != $cnt {
Err(Error::UnexpectedNumberOfDims {
expected: $cnt,
got: dims.len(),
shape: Shape::from(dims),
}
.bt())
} else {
Ok($dims(dims))
}
}
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
$fn_name(self.0.as_slice())
}
}
impl crate::Tensor {
pub fn $fn_name(&self) -> Result<$out_type> {
self.shape().$fn_name()
@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(