mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the permute op (similar to pytorch). (#504)
* Add the permute op (similar to pytorch). * Add the backprop for dimension permutation.
This commit is contained in:
@ -112,6 +112,31 @@ impl Layout {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
|
||||
let is_permutation =
|
||||
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
|
||||
if !is_permutation {
|
||||
crate::bail!(
|
||||
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
||||
self.dims(),
|
||||
idxs
|
||||
)
|
||||
}
|
||||
let stride = self.stride();
|
||||
let dims = self.shape().dims();
|
||||
let mut perm_stride = stride.to_vec();
|
||||
let mut perm_dims = dims.to_vec();
|
||||
for (i, &idx) in idxs.iter().enumerate() {
|
||||
perm_stride[i] = stride[idx];
|
||||
perm_dims[i] = dims[idx];
|
||||
}
|
||||
Ok(Self {
|
||||
shape: Shape::from(perm_dims),
|
||||
stride: perm_stride,
|
||||
start_offset: self.start_offset,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
if shape.rank() < self.shape().rank() {
|
||||
|
Reference in New Issue
Block a user