mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add get_on_dim. (#1142)
This commit is contained in:
@ -1598,6 +1598,24 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let t = tensor.get_on_dim(1, 0)?;
|
||||
/// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
|
||||
/// let t = tensor.get_on_dim(1, 1)?;
|
||||
/// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
|
||||
/// let t = tensor.get_on_dim(0, 1)?;
|
||||
/// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
|
||||
let dim = dim.to_index(self.shape(), "get_on_dim")?;
|
||||
self.narrow(dim, index, 1)?.squeeze(dim)
|
||||
}
|
||||
|
||||
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
|
||||
/// input are swapped.
|
||||
///
|
||||
|
Reference in New Issue
Block a user