mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Abstract the implementation of Shape.
This commit is contained in:
13
src/shape.rs
13
src/shape.rs
@ -1,7 +1,7 @@
|
||||
use crate::{Error, Result};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct Shape(pub(crate) Vec<usize>);
|
||||
pub struct Shape(Vec<usize>);
|
||||
|
||||
impl std::fmt::Debug for Shape {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@ -63,6 +63,12 @@ impl From<(usize, usize, usize)> for Shape {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<usize>> for Shape {
|
||||
fn from(dims: Vec<usize>) -> Self {
|
||||
Self(dims)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! extract_dims {
|
||||
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
@ -142,6 +148,11 @@ impl Shape {
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
|
||||
self.0.extend(additional_dims);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
Reference in New Issue
Block a user