From b4d332cd1b749e09562907feb4a415d3d3bddd9f Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 20 Jun 2023 11:32:27 +0100 Subject: [PATCH] Reduce code duplication. --- src/shape.rs | 95 +++++++++++++++++----------------------------------- 1 file changed, 31 insertions(+), 64 deletions(-) diff --git a/src/shape.rs b/src/shape.rs index 680460e8..d0fe0483 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -57,6 +57,22 @@ impl From<(usize, usize, usize)> for Shape { } } +macro_rules! extract_dims { + ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { + pub fn $fn_name(&self) -> Result<$out_type> { + if self.0.len() != $cnt { + Err(Error::UnexpectedNumberOfDims { + expected: $cnt, + got: self.0.len(), + shape: self.clone(), + }) + } else { + Ok($dims(&self.0)) + } + } + }; +} + impl Shape { pub fn from_dims(dims: &[usize]) -> Self { Self(dims.to_vec()) @@ -74,70 +90,21 @@ impl Shape { self.0.iter().product() } - pub fn r0(&self) -> Result<()> { - let shape = &self.0; - if shape.is_empty() { - Ok(()) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 0, - got: shape.len(), - shape: self.clone(), - }) - } - } - - pub fn r1(&self) -> Result { - let shape = &self.0; - if shape.len() == 1 { - Ok(shape[0]) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 1, - got: shape.len(), - shape: self.clone(), - }) - } - } - - pub fn r2(&self) -> Result<(usize, usize)> { - let shape = &self.0; - if shape.len() == 2 { - Ok((shape[0], shape[1])) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 2, - got: shape.len(), - shape: self.clone(), - }) - } - } - - pub fn r3(&self) -> Result<(usize, usize, usize)> { - let shape = &self.0; - if shape.len() == 3 { - Ok((shape[0], shape[1], shape[2])) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 3, - got: shape.len(), - shape: self.clone(), - }) - } - } - - pub fn r4(&self) -> Result<(usize, usize, usize, usize)> { - let shape = &self.0; - if shape.len() == 4 { - Ok((shape[0], shape[1], shape[2], shape[4])) - } else { - Err(Error::UnexpectedNumberOfDims { - expected: 4, - got: shape.len(), - shape: self.clone(), - }) - } - } + extract_dims!(r0, 0, |_: &Vec| (), ()); + extract_dims!(r1, 1, |d: &Vec| d[0], usize); + extract_dims!(r2, 2, |d: &Vec| (d[0], d[1]), (usize, usize)); + extract_dims!( + r3, + 3, + |d: &Vec| (d[0], d[1], d[2]), + (usize, usize, usize) + ); + extract_dims!( + r4, + 4, + |d: &Vec| (d[0], d[1], d[2], d[3]), + (usize, usize, usize, usize) + ); /// The strides given in number of elements for a contiguous n-dimensional /// arrays using this shape.