mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Tensor based indexing. (#842)
This commit is contained in:
@ -46,19 +46,31 @@ impl Tensor {
|
||||
current_dim += 1;
|
||||
out
|
||||
}
|
||||
TensorIndexer::IndexSelect(indexes) => {
|
||||
if indexes.rank() != 1 {
|
||||
crate::bail!("multi-dimensional tensor indexing is not supported")
|
||||
}
|
||||
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
|
||||
current_dim += 1;
|
||||
out
|
||||
}
|
||||
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
|
||||
};
|
||||
}
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
/// Generic structure used to index a slice of the tensor
|
||||
pub enum TensorIndexer {
|
||||
/// This selects the elemnts for which an index has some specific value.
|
||||
Select(usize),
|
||||
/// This is a regular slice, purely indexing a chunk of the tensor
|
||||
Narrow(Bound<usize>, Bound<usize>),
|
||||
/// Indexing via a 1d tensor
|
||||
IndexSelect(Tensor),
|
||||
Err(Error),
|
||||
}
|
||||
|
||||
impl From<usize> for TensorIndexer {
|
||||
@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[u32]> for TensorIndexer {
|
||||
fn from(index: &[u32]) -> Self {
|
||||
match Tensor::new(index, &crate::Device::Cpu) {
|
||||
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
|
||||
Err(e) => TensorIndexer::Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<u32>> for TensorIndexer {
|
||||
fn from(index: Vec<u32>) -> Self {
|
||||
let len = index.len();
|
||||
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
|
||||
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
|
||||
Err(e) => TensorIndexer::Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Tensor> for TensorIndexer {
|
||||
fn from(tensor: &Tensor) -> Self {
|
||||
TensorIndexer::IndexSelect(tensor.clone())
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_from_range {
|
||||
($range_type:ty) => {
|
||||
impl From<$range_type> for TensorIndexer {
|
||||
|
Reference in New Issue
Block a user