Tensor based indexing. (#842)

This commit is contained in:
Laurent Mazare
2023-09-14 08:47:07 +02:00
committed by GitHub
parent 49d3f7f708
commit d6447ad635

View File

@ -46,19 +46,31 @@ impl Tensor {
current_dim += 1; current_dim += 1;
out out
} }
TensorIndexer::IndexSelect(indexes) => {
if indexes.rank() != 1 {
crate::bail!("multi-dimensional tensor indexing is not supported")
}
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
current_dim += 1;
out
}
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
}; };
} }
Ok(x) Ok(x)
} }
} }
#[derive(Debug, Clone)] #[derive(Debug)]
/// Generic structure used to index a slice of the tensor /// Generic structure used to index a slice of the tensor
pub enum TensorIndexer { pub enum TensorIndexer {
/// This selects the elemnts for which an index has some specific value. /// This selects the elemnts for which an index has some specific value.
Select(usize), Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor /// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>), Narrow(Bound<usize>, Bound<usize>),
/// Indexing via a 1d tensor
IndexSelect(Tensor),
Err(Error),
} }
impl From<usize> for TensorIndexer { impl From<usize> for TensorIndexer {
@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
} }
} }
impl From<&[u32]> for TensorIndexer {
fn from(index: &[u32]) -> Self {
match Tensor::new(index, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<Vec<u32>> for TensorIndexer {
fn from(index: Vec<u32>) -> Self {
let len = index.len();
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}
macro_rules! impl_from_range { macro_rules! impl_from_range {
($range_type:ty) => { ($range_type:ty) => {
impl From<$range_type> for TensorIndexer { impl From<$range_type> for TensorIndexer {