mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Reduce code duplication.
This commit is contained in:
95
src/shape.rs
95
src/shape.rs
@ -57,6 +57,22 @@ impl From<(usize, usize, usize)> for Shape {
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! extract_dims {
|
||||
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
if self.0.len() != $cnt {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: $cnt,
|
||||
got: self.0.len(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
} else {
|
||||
Ok($dims(&self.0))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
pub fn from_dims(dims: &[usize]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
@ -74,70 +90,21 @@ impl Shape {
|
||||
self.0.iter().product()
|
||||
}
|
||||
|
||||
pub fn r0(&self) -> Result<()> {
|
||||
let shape = &self.0;
|
||||
if shape.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 0,
|
||||
got: shape.len(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn r1(&self) -> Result<usize> {
|
||||
let shape = &self.0;
|
||||
if shape.len() == 1 {
|
||||
Ok(shape[0])
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 1,
|
||||
got: shape.len(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn r2(&self) -> Result<(usize, usize)> {
|
||||
let shape = &self.0;
|
||||
if shape.len() == 2 {
|
||||
Ok((shape[0], shape[1]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 2,
|
||||
got: shape.len(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn r3(&self) -> Result<(usize, usize, usize)> {
|
||||
let shape = &self.0;
|
||||
if shape.len() == 3 {
|
||||
Ok((shape[0], shape[1], shape[2]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 3,
|
||||
got: shape.len(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn r4(&self) -> Result<(usize, usize, usize, usize)> {
|
||||
let shape = &self.0;
|
||||
if shape.len() == 4 {
|
||||
Ok((shape[0], shape[1], shape[2], shape[4]))
|
||||
} else {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 4,
|
||||
got: shape.len(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
||||
extract_dims!(r1, 1, |d: &Vec<usize>| d[0], usize);
|
||||
extract_dims!(r2, 2, |d: &Vec<usize>| (d[0], d[1]), (usize, usize));
|
||||
extract_dims!(
|
||||
r3,
|
||||
3,
|
||||
|d: &Vec<usize>| (d[0], d[1], d[2]),
|
||||
(usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r4,
|
||||
4,
|
||||
|d: &Vec<usize>| (d[0], d[1], d[2], d[3]),
|
||||
(usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
/// The strides given in number of elements for a contiguous n-dimensional
|
||||
/// arrays using this shape.
|
||||
|
Reference in New Issue
Block a user