mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Start adding some ops.
This commit is contained in:
16
src/shape.rs
16
src/shape.rs
@ -1,4 +1,6 @@
|
||||
use crate::{Error, Result};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct Shape(pub(crate) Vec<usize>);
|
||||
|
||||
impl std::fmt::Debug for Shape {
|
||||
@ -56,6 +58,10 @@ impl From<(usize, usize, usize)> for Shape {
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
pub fn from_dims(dims: &[usize]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
@ -76,7 +82,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 0,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -89,7 +95,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 1,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -102,7 +108,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 2,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -115,7 +121,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 3,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -128,7 +134,7 @@ impl Shape {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: 4,
|
||||
got: shape.len(),
|
||||
shape: shape.to_vec(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user