mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the permute op (similar to pytorch). (#504)
* Add the permute op (similar to pytorch). * Add the backprop for dimension permutation.
This commit is contained in:
@ -1459,6 +1459,42 @@ impl Tensor {
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
/// Returns a tensor with the same data as the input where the dimensions have been permuted.
|
||||
/// dims must be a permutation, i.e. include each dimension index exactly once.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
|
||||
/// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
|
||||
/// let tensor = tensor.permute((2, 3, 1, 0))?;
|
||||
/// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
|
||||
let dims = dims.to_indexes(self.shape(), "permute")?;
|
||||
// O(n^2) permutation check but these arrays are small.
|
||||
let is_permutation =
|
||||
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
|
||||
if !is_permutation {
|
||||
crate::bail!(
|
||||
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
|
||||
self.dims(),
|
||||
dims
|
||||
)
|
||||
}
|
||||
let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout.permute(&dims)?,
|
||||
op,
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
/// Returns true if the data is stored in a C contiguous (aka row major) way.
|
||||
pub fn is_contiguous(&self) -> bool {
|
||||
self.layout.is_contiguous()
|
||||
|
Reference in New Issue
Block a user