mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Reduce code duplication.
This commit is contained in:
95
src/shape.rs
95
src/shape.rs
@ -57,6 +57,22 @@ impl From<(usize, usize, usize)> for Shape {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
macro_rules! extract_dims {
|
||||||
|
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
||||||
|
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||||
|
if self.0.len() != $cnt {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: $cnt,
|
||||||
|
got: self.0.len(),
|
||||||
|
shape: self.clone(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok($dims(&self.0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
impl Shape {
|
impl Shape {
|
||||||
pub fn from_dims(dims: &[usize]) -> Self {
|
pub fn from_dims(dims: &[usize]) -> Self {
|
||||||
Self(dims.to_vec())
|
Self(dims.to_vec())
|
||||||
@ -74,70 +90,21 @@ impl Shape {
|
|||||||
self.0.iter().product()
|
self.0.iter().product()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn r0(&self) -> Result<()> {
|
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
||||||
let shape = &self.0;
|
extract_dims!(r1, 1, |d: &Vec<usize>| d[0], usize);
|
||||||
if shape.is_empty() {
|
extract_dims!(r2, 2, |d: &Vec<usize>| (d[0], d[1]), (usize, usize));
|
||||||
Ok(())
|
extract_dims!(
|
||||||
} else {
|
r3,
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
3,
|
||||||
expected: 0,
|
|d: &Vec<usize>| (d[0], d[1], d[2]),
|
||||||
got: shape.len(),
|
(usize, usize, usize)
|
||||||
shape: self.clone(),
|
);
|
||||||
})
|
extract_dims!(
|
||||||
}
|
r4,
|
||||||
}
|
4,
|
||||||
|
|d: &Vec<usize>| (d[0], d[1], d[2], d[3]),
|
||||||
pub fn r1(&self) -> Result<usize> {
|
(usize, usize, usize, usize)
|
||||||
let shape = &self.0;
|
);
|
||||||
if shape.len() == 1 {
|
|
||||||
Ok(shape[0])
|
|
||||||
} else {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: 1,
|
|
||||||
got: shape.len(),
|
|
||||||
shape: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn r2(&self) -> Result<(usize, usize)> {
|
|
||||||
let shape = &self.0;
|
|
||||||
if shape.len() == 2 {
|
|
||||||
Ok((shape[0], shape[1]))
|
|
||||||
} else {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: 2,
|
|
||||||
got: shape.len(),
|
|
||||||
shape: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn r3(&self) -> Result<(usize, usize, usize)> {
|
|
||||||
let shape = &self.0;
|
|
||||||
if shape.len() == 3 {
|
|
||||||
Ok((shape[0], shape[1], shape[2]))
|
|
||||||
} else {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: 3,
|
|
||||||
got: shape.len(),
|
|
||||||
shape: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn r4(&self) -> Result<(usize, usize, usize, usize)> {
|
|
||||||
let shape = &self.0;
|
|
||||||
if shape.len() == 4 {
|
|
||||||
Ok((shape[0], shape[1], shape[2], shape[4]))
|
|
||||||
} else {
|
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: 4,
|
|
||||||
got: shape.len(),
|
|
||||||
shape: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The strides given in number of elements for a contiguous n-dimensional
|
/// The strides given in number of elements for a contiguous n-dimensional
|
||||||
/// arrays using this shape.
|
/// arrays using this shape.
|
||||||
|
Reference in New Issue
Block a user