Add the permute op (similar to pytorch). (#504)

* Add the permute op (similar to pytorch).

* Add the backprop for dimension permutation.
This commit is contained in:
Laurent Mazare
2023-08-18 16:30:53 +01:00
committed by GitHub
parent 4f1541526c
commit cb069d6063
7 changed files with 85 additions and 4 deletions

View File

@ -1459,6 +1459,42 @@ impl Tensor {
Ok(Tensor(Arc::new(tensor_)))
}
/// Returns a tensor with the same data as the input where the dimensions have been permuted.
/// dims must be a permutation, i.e. include each dimension index exactly once.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
/// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
/// let tensor = tensor.permute((2, 3, 1, 0))?;
/// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
let dims = dims.to_indexes(self.shape(), "permute")?;
// O(n^2) permutation check but these arrays are small.
let is_permutation =
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
if !is_permutation {
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
dims
)
}
let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.permute(&dims)?,
op,
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
/// Returns true if the data is stored in a C contiguous (aka row major) way.
pub fn is_contiguous(&self) -> bool {
self.layout.is_contiguous()