From d6447ad635bc450ef1f15ca7a4424c0f86e7a90a Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 14 Sep 2023 08:47:07 +0200 Subject: [PATCH] Tensor based indexing. (#842) --- candle-core/src/indexer.rs | 39 +++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index 2b6d694b..7b84d316 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -46,19 +46,31 @@ impl Tensor { current_dim += 1; out } + TensorIndexer::IndexSelect(indexes) => { + if indexes.rank() != 1 { + crate::bail!("multi-dimensional tensor indexing is not supported") + } + let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?; + current_dim += 1; + out + } + TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"), }; } Ok(x) } } -#[derive(Debug, Clone)] +#[derive(Debug)] /// Generic structure used to index a slice of the tensor pub enum TensorIndexer { /// This selects the elemnts for which an index has some specific value. Select(usize), /// This is a regular slice, purely indexing a chunk of the tensor Narrow(Bound, Bound), + /// Indexing via a 1d tensor + IndexSelect(Tensor), + Err(Error), } impl From for TensorIndexer { @@ -67,6 +79,31 @@ impl From for TensorIndexer { } } +impl From<&[u32]> for TensorIndexer { + fn from(index: &[u32]) -> Self { + match Tensor::new(index, &crate::Device::Cpu) { + Ok(tensor) => TensorIndexer::IndexSelect(tensor), + Err(e) => TensorIndexer::Err(e), + } + } +} + +impl From> for TensorIndexer { + fn from(index: Vec) -> Self { + let len = index.len(); + match Tensor::from_vec(index, len, &crate::Device::Cpu) { + Ok(tensor) => TensorIndexer::IndexSelect(tensor), + Err(e) => TensorIndexer::Err(e), + } + } +} + +impl From<&Tensor> for TensorIndexer { + fn from(tensor: &Tensor) -> Self { + TensorIndexer::IndexSelect(tensor.clone()) + } +} + macro_rules! impl_from_range { ($range_type:ty) => { impl From<$range_type> for TensorIndexer {