add roll function to tensor (#1666)

This commit is contained in:
Jiayu Liu
2024-02-06 15:49:45 +08:00
committed by GitHub
parent a83ca2ece0
commit 982722019b

View File

@ -804,6 +804,34 @@ impl Tensor {
}
}
/// Roll the tensor input along the given dimension.
/// Elements that are shifted beyond the last position are re-introduced at the first position.
///
/// ```rust
/// # use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(-1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
where
D: Dim + Clone,
{
let dim_size = self.dim(dim.clone())?;
let shift = shift.rem_euclid(dim_size as i32) as usize;
if shift == 0 {
return Ok(self.clone());
} else {
let a = self.narrow(dim.clone(), 0, dim_size - shift)?;
let b = self.narrow(dim.clone(), dim_size - shift, shift)?;
Tensor::cat(&[&b, &a], dim)
}
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions.
///