From e8f760ee44ad4b1f9f3606e36a1966df8509203b Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 21 Oct 2023 15:01:38 +0100 Subject: [PATCH] Add get_on_dim. (#1142) --- candle-core/src/tensor.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index da47d180..0ffed2fe 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1598,6 +1598,24 @@ impl Tensor { } } + /// Returns the sub-tensor fixing the index at `index` on the dimension `dim`. + /// + /// ```rust + /// use candle_core::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let t = tensor.get_on_dim(1, 0)?; + /// assert_eq!(t.to_vec1::()?, &[0., 2., 4.]); + /// let t = tensor.get_on_dim(1, 1)?; + /// assert_eq!(t.to_vec1::()?, &[1., 3., 5.]); + /// let t = tensor.get_on_dim(0, 1)?; + /// assert_eq!(t.to_vec1::()?, &[2., 3.]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn get_on_dim(&self, dim: D, index: usize) -> Result { + let dim = dim.to_index(self.shape(), "get_on_dim")?; + self.narrow(dim, index, 1)?.squeeze(dim) + } + /// Returns a tensor that is a transposed version of the input, the two last dimensions of the /// input are swapped. ///