Reduce code duplication.

This commit is contained in:
laurent
2023-06-20 11:32:27 +01:00
parent f5b0aa815a
commit b4d332cd1b

View File

@ -57,6 +57,22 @@ impl From<(usize, usize, usize)> for Shape {
}
}
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
pub fn $fn_name(&self) -> Result<$out_type> {
if self.0.len() != $cnt {
Err(Error::UnexpectedNumberOfDims {
expected: $cnt,
got: self.0.len(),
shape: self.clone(),
})
} else {
Ok($dims(&self.0))
}
}
};
}
impl Shape {
pub fn from_dims(dims: &[usize]) -> Self {
Self(dims.to_vec())
@ -74,70 +90,21 @@ impl Shape {
self.0.iter().product()
}
pub fn r0(&self) -> Result<()> {
let shape = &self.0;
if shape.is_empty() {
Ok(())
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 0,
got: shape.len(),
shape: self.clone(),
})
}
}
pub fn r1(&self) -> Result<usize> {
let shape = &self.0;
if shape.len() == 1 {
Ok(shape[0])
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 1,
got: shape.len(),
shape: self.clone(),
})
}
}
pub fn r2(&self) -> Result<(usize, usize)> {
let shape = &self.0;
if shape.len() == 2 {
Ok((shape[0], shape[1]))
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 2,
got: shape.len(),
shape: self.clone(),
})
}
}
pub fn r3(&self) -> Result<(usize, usize, usize)> {
let shape = &self.0;
if shape.len() == 3 {
Ok((shape[0], shape[1], shape[2]))
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 3,
got: shape.len(),
shape: self.clone(),
})
}
}
pub fn r4(&self) -> Result<(usize, usize, usize, usize)> {
let shape = &self.0;
if shape.len() == 4 {
Ok((shape[0], shape[1], shape[2], shape[4]))
} else {
Err(Error::UnexpectedNumberOfDims {
expected: 4,
got: shape.len(),
shape: self.clone(),
})
}
}
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
extract_dims!(r1, 1, |d: &Vec<usize>| d[0], usize);
extract_dims!(r2, 2, |d: &Vec<usize>| (d[0], d[1]), (usize, usize));
extract_dims!(
r3,
3,
|d: &Vec<usize>| (d[0], d[1], d[2]),
(usize, usize, usize)
);
extract_dims!(
r4,
4,
|d: &Vec<usize>| (d[0], d[1], d[2], d[3]),
(usize, usize, usize, usize)
);
/// The strides given in number of elements for a contiguous n-dimensional
/// arrays using this shape.