From 684f66326d76658dd3abcc6210c225461744d1bc Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 17:39:58 +0100 Subject: [PATCH] Add the get method. --- candle-core/src/tensor.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index b90bb2d6..cb968487 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -677,6 +677,15 @@ impl Tensor { self.flatten(None, None) } + pub fn get(&self, i: usize) -> Result { + let dims = self.dims(); + if dims.is_empty() { + Ok(self.clone()) + } else { + self.narrow(0, i, 1)?.reshape(&dims[1..]) + } + } + /// Returns a tensor that is a transposed version of the input, the two last dimensions of the /// input are swapped. pub fn t(&self) -> Result {