mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Adding some doc + Extended stack
to work with extra final dimensions.
This commit is contained in:
@ -185,6 +185,7 @@ impl Shape {
|
||||
|
||||
pub trait Dim {
|
||||
fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
|
||||
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
|
||||
}
|
||||
|
||||
impl Dim for usize {
|
||||
@ -200,6 +201,19 @@ impl Dim for usize {
|
||||
Ok(dim)
|
||||
}
|
||||
}
|
||||
|
||||
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
|
||||
let dim = *self;
|
||||
if dim > shape.dims().len() {
|
||||
Err(Error::DimOutOfRange {
|
||||
shape: shape.clone(),
|
||||
dim,
|
||||
op,
|
||||
})?
|
||||
} else {
|
||||
Ok(dim)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum D {
|
||||
@ -220,6 +234,19 @@ impl Dim for D {
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
|
||||
let rank = shape.rank();
|
||||
match self {
|
||||
Self::Minus1 if rank >= 1 => Ok(rank),
|
||||
Self::Minus2 if rank >= 2 => Ok(rank - 1),
|
||||
_ => Err(Error::DimOutOfRange {
|
||||
shape: shape.clone(),
|
||||
dim: 42, // TODO: Have an adequate error
|
||||
op,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
Reference in New Issue
Block a user