mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
add roll function to tensor (#1666)
This commit is contained in:
@ -804,6 +804,34 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Roll the tensor input along the given dimension.
|
||||
/// Elements that are shifted beyond the last position are re-introduced at the first position.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.roll(1, 0)?;
|
||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.roll(-1, 0)?;
|
||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
|
||||
where
|
||||
D: Dim + Clone,
|
||||
{
|
||||
let dim_size = self.dim(dim.clone())?;
|
||||
let shift = shift.rem_euclid(dim_size as i32) as usize;
|
||||
if shift == 0 {
|
||||
return Ok(self.clone());
|
||||
} else {
|
||||
let a = self.narrow(dim.clone(), 0, dim_size - shift)?;
|
||||
let b = self.narrow(dim.clone(), dim_size - shift, shift)?;
|
||||
Tensor::cat(&[&b, &a], dim)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||
/// input dimensions.
|
||||
///
|
||||
|
Reference in New Issue
Block a user