diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index b90bb2d6..cb968487 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -677,6 +677,15 @@ impl Tensor { self.flatten(None, None) } + pub fn get(&self, i: usize) -> Result { + let dims = self.dims(); + if dims.is_empty() { + Ok(self.clone()) + } else { + self.narrow(0, i, 1)?.reshape(&dims[1..]) + } + } + /// Returns a tensor that is a transposed version of the input, the two last dimensions of the /// input are swapped. pub fn t(&self) -> Result {