mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Propagate the layout refactoring.
This commit is contained in:
@ -9,16 +9,20 @@ pub struct Layout {
|
||||
}
|
||||
|
||||
impl Layout {
|
||||
pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
|
||||
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
|
||||
let shape = shape.into();
|
||||
let stride = shape.stride_contiguous();
|
||||
Self {
|
||||
shape,
|
||||
stride,
|
||||
start_offset: 0,
|
||||
start_offset,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
|
||||
Self::contiguous_with_offset(shape, 0)
|
||||
}
|
||||
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
self.shape.dims()
|
||||
}
|
||||
@ -45,7 +49,7 @@ impl Layout {
|
||||
self.shape.is_fortran_contiguous(&self.stride)
|
||||
}
|
||||
|
||||
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
|
||||
pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
|
||||
let dims = self.shape().dims();
|
||||
if dim >= dims.len() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
@ -65,4 +69,61 @@ impl Layout {
|
||||
start_offset: self.start_offset + self.stride[dim] * start,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
|
||||
let rank = self.shape.rank();
|
||||
if rank <= dim1 || rank <= dim2 {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
expected: usize::max(dim1, dim2),
|
||||
got: rank,
|
||||
shape: self.shape().clone(),
|
||||
});
|
||||
}
|
||||
let mut stride = self.stride().to_vec();
|
||||
let mut dims = self.shape().dims().to_vec();
|
||||
dims.swap(dim1, dim2);
|
||||
stride.swap(dim1, dim2);
|
||||
Ok(Self {
|
||||
shape: Shape::from(dims),
|
||||
stride,
|
||||
start_offset: self.start_offset,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
if shape.rank() < self.shape().rank() {
|
||||
Err(Error::BroadcastIncompatibleShapes {
|
||||
src_shape: self.shape().clone(),
|
||||
dst_shape: shape,
|
||||
})?
|
||||
}
|
||||
let added_dims = shape.rank() - self.shape().rank();
|
||||
let mut stride = vec![0; added_dims];
|
||||
for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
|
||||
.iter()
|
||||
.zip(self.dims().iter().zip(self.stride()))
|
||||
{
|
||||
let s = if dst_dim == src_dim {
|
||||
src_stride
|
||||
} else if src_dim != 1 {
|
||||
return Err(Error::BroadcastIncompatibleShapes {
|
||||
src_shape: self.shape().clone(),
|
||||
dst_shape: shape,
|
||||
});
|
||||
} else {
|
||||
0
|
||||
};
|
||||
stride.push(s)
|
||||
}
|
||||
Ok(Self {
|
||||
shape,
|
||||
stride,
|
||||
start_offset: self.start_offset,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
||||
crate::StridedIndex::new(&self)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user