Add get_on_dim. (#1142)

This commit is contained in:
Laurent Mazare
2023-10-21 15:01:38 +01:00
committed by GitHub
parent 94e3373883
commit e8f760ee44

View File

@ -1598,6 +1598,24 @@ impl Tensor {
}
}
/// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let t = tensor.get_on_dim(1, 0)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
/// let t = tensor.get_on_dim(1, 1)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
/// let t = tensor.get_on_dim(0, 1)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
let dim = dim.to_index(self.shape(), "get_on_dim")?;
self.narrow(dim, index, 1)?.squeeze(dim)
}
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
/// input are swapped.
///