From 1892bd139ce498f3f73436fc63b4fba87066bc8c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 9 Aug 2023 18:57:05 +0200 Subject: [PATCH] Extract the strides in the conv ops. (#370) --- candle-core/src/cpu_backend.rs | 38 ++++++++++++++-------------------- candle-core/src/shape.rs | 29 +++++++++++++++----------- 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 155df1e9..1f94a9bc 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -992,19 +992,14 @@ impl<'a> Map2 for Conv1D<'a> { let p = self.0; let inp = &inp[inp_l.start_offset()..]; let k = &k[k_l.start_offset()..]; - let inp_stride = inp_l.stride(); - let (inp_stride0, inp_stride) = if inp_stride.len() == 3 { - (inp_stride[0], &inp_stride[1..]) - } else { - (0, inp_stride) // This value never gets used anyway - }; - let k_stride = k_l.stride(); + let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; + let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; let l_out = p.l_out(); let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1); let mut dst = vec![T::zero(); dst_elems]; // The output shape is [b_size, c_out, l_out] for b_idx in 0..p.b_size.unwrap_or(1) { - let inp_idx = b_idx * inp_stride0; + let inp_idx = b_idx * inp_s0; let dst_idx = b_idx * p.c_out * l_out; for dst_c_idx in 0..p.c_out { let dst_idx = dst_idx + dst_c_idx * l_out; @@ -1016,11 +1011,8 @@ impl<'a> Map2 for Conv1D<'a> { .saturating_sub(p.padding) .min(p.l_in - 1); for src_c_idx in 0..p.c_in { - let inp_idx = - inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; - let k_idx = dst_c_idx * k_stride[0] - + src_c_idx * k_stride[1] - + offset * k_stride[2]; + let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2; + let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2; d += inp[inp_idx] * k[k_idx] } } @@ -1045,14 +1037,14 @@ impl<'a> Map2 for Conv2D<'a> { ) -> Result> { let p = self.0; let inp = &inp[inp_l.start_offset()..]; - let inp_stride = inp_l.stride(); + let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?; let k = &k[k_l.start_offset()..]; - let k_stride = k_l.stride(); + let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?; let (out_h, out_w) = (p.out_h(), p.out_w()); let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; for b_idx in 0..p.b_size { - let inp_idx = b_idx * inp_stride[0]; + let inp_idx = b_idx * inp_s0; let dst_idx = b_idx * p.c_out * out_h * out_w; for dst_c_idx in 0..p.c_out { let dst_idx = dst_idx + dst_c_idx * out_h * out_w; @@ -1071,13 +1063,13 @@ impl<'a> Map2 for Conv2D<'a> { .min(p.i_w - 1); for src_c_idx in 0..p.c_in { let inp_idx = inp_idx - + src_c_idx * inp_stride[1] - + src_h * inp_stride[2] - + src_w * inp_stride[3]; - let k_idx = dst_c_idx * k_stride[0] - + src_c_idx * k_stride[1] - + offset_h * k_stride[2] - + offset_w * k_stride[3]; + + src_c_idx * inp_s1 + + src_h * inp_s2 + + src_w * inp_s3; + let k_idx = dst_c_idx * k_s0 + + src_c_idx * k_s1 + + offset_h * k_s2 + + offset_w * k_s3; d += inp[inp_idx] * k[k_idx] } } diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index a5e21aad..83d11c09 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -79,20 +79,25 @@ impl From> for Shape { macro_rules! extract_dims { ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { - impl Shape { - pub fn $fn_name(&self) -> Result<$out_type> { - if self.0.len() != $cnt { - Err(Error::UnexpectedNumberOfDims { - expected: $cnt, - got: self.0.len(), - shape: self.clone(), - } - .bt()) - } else { - Ok($dims(&self.0)) + pub fn $fn_name(dims: &[usize]) -> Result<$out_type> { + if dims.len() != $cnt { + Err(Error::UnexpectedNumberOfDims { + expected: $cnt, + got: dims.len(), + shape: Shape::from(dims), } + .bt()) + } else { + Ok($dims(dims)) } } + + impl Shape { + pub fn $fn_name(&self) -> Result<$out_type> { + $fn_name(self.0.as_slice()) + } + } + impl crate::Tensor { pub fn $fn_name(&self) -> Result<$out_type> { self.shape().$fn_name() @@ -340,7 +345,7 @@ impl Dims for (D1, D2, D3) { } } -extract_dims!(dims0, 0, |_: &Vec| (), ()); +extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); extract_dims!(