mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Expose a couple layout methods. (#1816)
This commit is contained in:
@ -70,7 +70,7 @@ impl Layout {
|
|||||||
self.shape.is_fortran_contiguous(&self.stride)
|
self.shape.is_fortran_contiguous(&self.stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||||
let dims = self.shape().dims();
|
let dims = self.shape().dims();
|
||||||
if dim >= dims.len() {
|
if dim >= dims.len() {
|
||||||
Err(Error::DimOutOfRange {
|
Err(Error::DimOutOfRange {
|
||||||
@ -99,7 +99,7 @@ impl Layout {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
||||||
let rank = self.shape.rank();
|
let rank = self.shape.rank();
|
||||||
if rank <= dim1 || rank <= dim2 {
|
if rank <= dim1 || rank <= dim2 {
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
@ -120,7 +120,7 @@ impl Layout {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
||||||
let is_permutation =
|
let is_permutation =
|
||||||
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
||||||
if !is_permutation {
|
if !is_permutation {
|
||||||
|
Reference in New Issue
Block a user