Simplify the narrow implementation.

This commit is contained in:
laurent
2023-06-28 13:09:59 +01:00
parent c1bbbf94f6
commit 30b355ccd2
3 changed files with 36 additions and 34 deletions

View File

@ -1,4 +1,4 @@
use crate::Shape;
use crate::{Error, Result, Shape};
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Layout {
@ -44,4 +44,25 @@ impl Layout {
pub fn is_fortran_contiguous(&self) -> bool {
self.shape.is_fortran_contiguous(&self.stride)
}
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
let dims = self.shape().dims();
if dim >= dims.len() {
Err(Error::UnexpectedNumberOfDims {
expected: dim + 1,
got: dims.len(),
shape: self.shape().clone(),
})?
}
if start + length > dims[dim] {
todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
}
let mut dims = dims.to_vec();
dims[dim] = length;
Ok(Self {
shape: Shape::from(dims),
stride: self.stride.clone(),
start_offset: self.start_offset + self.stride[dim] * start,
})
}
}