mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add Shape try into (#189)
* Add the TryInto trait for shapes. * Use the vectorized operations in block mode too.
This commit is contained in:
@ -73,16 +73,24 @@ impl From<Vec<usize>> for Shape {
|
||||
|
||||
macro_rules! extract_dims {
|
||||
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
if self.0.len() != $cnt {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: $cnt,
|
||||
got: self.0.len(),
|
||||
shape: self.clone(),
|
||||
impl Shape {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
if self.0.len() != $cnt {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: $cnt,
|
||||
got: self.0.len(),
|
||||
shape: self.clone(),
|
||||
}
|
||||
.bt())
|
||||
} else {
|
||||
Ok($dims(&self.0))
|
||||
}
|
||||
.bt())
|
||||
} else {
|
||||
Ok($dims(&self.0))
|
||||
}
|
||||
}
|
||||
impl std::convert::TryInto<$out_type> for Shape {
|
||||
type Error = crate::Error;
|
||||
fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
|
||||
self.$fn_name()
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -109,28 +117,6 @@ impl Shape {
|
||||
self.0.iter().product()
|
||||
}
|
||||
|
||||
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
||||
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
extract_dims!(
|
||||
r3,
|
||||
3,
|
||||
|d: &[usize]| (d[0], d[1], d[2]),
|
||||
(usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r4,
|
||||
4,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||
(usize, usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r5,
|
||||
5,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||
(usize, usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
/// The strides given in number of elements for a contiguous n-dimensional
|
||||
/// arrays using this shape.
|
||||
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
||||
@ -342,6 +328,28 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
|
||||
}
|
||||
}
|
||||
|
||||
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
||||
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
extract_dims!(
|
||||
r3,
|
||||
3,
|
||||
|d: &[usize]| (d[0], d[1], d[2]),
|
||||
(usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r4,
|
||||
4,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||
(usize, usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r5,
|
||||
5,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||
(usize, usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
Reference in New Issue
Block a user