diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 5d0f47c9..5f0b6df9 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -821,13 +821,14 @@ impl Tensor { where D: Dim + Clone, { - let dim_size = self.dim(dim.clone())?; + let dim = dim.to_index(self.shape(), "roll")?; + let dim_size = self.dim(dim)?; let shift = shift.rem_euclid(dim_size as i32) as usize; if shift == 0 { - return Ok(self.clone()); + Ok(self.clone()) } else { - let a = self.narrow(dim.clone(), 0, dim_size - shift)?; - let b = self.narrow(dim.clone(), dim_size - shift, shift)?; + let a = self.narrow(dim, 0, dim_size - shift)?; + let b = self.narrow(dim, dim_size - shift, shift)?; Tensor::cat(&[&b, &a], dim) } }