Add the permute op (similar to pytorch). (#504)

* Add the permute op (similar to pytorch).

* Add the backprop for dimension permutation.
This commit is contained in:
Laurent Mazare
2023-08-18 16:30:53 +01:00
committed by GitHub
parent 4f1541526c
commit cb069d6063
7 changed files with 85 additions and 4 deletions

View File

@ -112,6 +112,31 @@ impl Layout {
})
}
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
let is_permutation =
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
if !is_permutation {
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
idxs
)
}
let stride = self.stride();
let dims = self.shape().dims();
let mut perm_stride = stride.to_vec();
let mut perm_dims = dims.to_vec();
for (i, &idx) in idxs.iter().enumerate() {
perm_stride[i] = stride[idx];
perm_dims[i] = dims[idx];
}
Ok(Self {
shape: Shape::from(perm_dims),
stride: perm_stride,
start_offset: self.start_offset,
})
}
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let shape = shape.into();
if shape.rank() < self.shape().rank() {