diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 3100c6e8..5d0f47c9 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -804,6 +804,34 @@ impl Tensor { } } + /// Roll the tensor input along the given dimension. + /// Elements that are shifted beyond the last position are re-introduced at the first position. + /// + /// ```rust + /// # use candle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.roll(1, 0)?; + /// assert_eq!(tensor.to_vec2::()?, &[[4., 5.], [0., 1.], [2., 3.]]); + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.roll(-1, 0)?; + /// assert_eq!(tensor.to_vec2::()?, &[[2., 3.], [4., 5.], [0., 1.]]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn roll(&self, shift: i32, dim: D) -> Result + where + D: Dim + Clone, + { + let dim_size = self.dim(dim.clone())?; + let shift = shift.rem_euclid(dim_size as i32) as usize; + if shift == 0 { + return Ok(self.clone()); + } else { + let a = self.narrow(dim.clone(), 0, dim_size - shift)?; + let b = self.narrow(dim.clone(), dim_size - shift, shift)?; + Tensor::cat(&[&b, &a], dim) + } + } + /// Returns the sum of all elements in the input tensor. The sum is performed over all the /// input dimensions. ///