Propagate the layout refactoring.

This commit is contained in:
laurent
2023-06-28 13:42:23 +01:00
parent 30b355ccd2
commit 303b853098
5 changed files with 130 additions and 129 deletions

View File

@ -9,16 +9,20 @@ pub struct Layout {
}
impl Layout {
pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
let shape = shape.into();
let stride = shape.stride_contiguous();
Self {
shape,
stride,
start_offset: 0,
start_offset,
}
}
pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
Self::contiguous_with_offset(shape, 0)
}
pub fn dims(&self) -> &[usize] {
self.shape.dims()
}
@ -45,7 +49,7 @@ impl Layout {
self.shape.is_fortran_contiguous(&self.stride)
}
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
let dims = self.shape().dims();
if dim >= dims.len() {
Err(Error::UnexpectedNumberOfDims {
@ -65,4 +69,61 @@ impl Layout {
start_offset: self.start_offset + self.stride[dim] * start,
})
}
pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
let rank = self.shape.rank();
if rank <= dim1 || rank <= dim2 {
return Err(Error::UnexpectedNumberOfDims {
expected: usize::max(dim1, dim2),
got: rank,
shape: self.shape().clone(),
});
}
let mut stride = self.stride().to_vec();
let mut dims = self.shape().dims().to_vec();
dims.swap(dim1, dim2);
stride.swap(dim1, dim2);
Ok(Self {
shape: Shape::from(dims),
stride,
start_offset: self.start_offset,
})
}
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let shape = shape.into();
if shape.rank() < self.shape().rank() {
Err(Error::BroadcastIncompatibleShapes {
src_shape: self.shape().clone(),
dst_shape: shape,
})?
}
let added_dims = shape.rank() - self.shape().rank();
let mut stride = vec![0; added_dims];
for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
.iter()
.zip(self.dims().iter().zip(self.stride()))
{
let s = if dst_dim == src_dim {
src_stride
} else if src_dim != 1 {
return Err(Error::BroadcastIncompatibleShapes {
src_shape: self.shape().clone(),
dst_shape: shape,
});
} else {
0
};
stride.push(s)
}
Ok(Self {
shape,
stride,
start_offset: self.start_offset,
})
}
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(&self)
}
}