mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Tensor based indexing. (#842)
This commit is contained in:
@ -46,19 +46,31 @@ impl Tensor {
|
|||||||
current_dim += 1;
|
current_dim += 1;
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
TensorIndexer::IndexSelect(indexes) => {
|
||||||
|
if indexes.rank() != 1 {
|
||||||
|
crate::bail!("multi-dimensional tensor indexing is not supported")
|
||||||
|
}
|
||||||
|
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
|
||||||
|
current_dim += 1;
|
||||||
|
out
|
||||||
|
}
|
||||||
|
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
Ok(x)
|
Ok(x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug)]
|
||||||
/// Generic structure used to index a slice of the tensor
|
/// Generic structure used to index a slice of the tensor
|
||||||
pub enum TensorIndexer {
|
pub enum TensorIndexer {
|
||||||
/// This selects the elemnts for which an index has some specific value.
|
/// This selects the elemnts for which an index has some specific value.
|
||||||
Select(usize),
|
Select(usize),
|
||||||
/// This is a regular slice, purely indexing a chunk of the tensor
|
/// This is a regular slice, purely indexing a chunk of the tensor
|
||||||
Narrow(Bound<usize>, Bound<usize>),
|
Narrow(Bound<usize>, Bound<usize>),
|
||||||
|
/// Indexing via a 1d tensor
|
||||||
|
IndexSelect(Tensor),
|
||||||
|
Err(Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<usize> for TensorIndexer {
|
impl From<usize> for TensorIndexer {
|
||||||
@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<&[u32]> for TensorIndexer {
|
||||||
|
fn from(index: &[u32]) -> Self {
|
||||||
|
match Tensor::new(index, &crate::Device::Cpu) {
|
||||||
|
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
|
||||||
|
Err(e) => TensorIndexer::Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Vec<u32>> for TensorIndexer {
|
||||||
|
fn from(index: Vec<u32>) -> Self {
|
||||||
|
let len = index.len();
|
||||||
|
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
|
||||||
|
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
|
||||||
|
Err(e) => TensorIndexer::Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&Tensor> for TensorIndexer {
|
||||||
|
fn from(tensor: &Tensor) -> Self {
|
||||||
|
TensorIndexer::IndexSelect(tensor.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
macro_rules! impl_from_range {
|
macro_rules! impl_from_range {
|
||||||
($range_type:ty) => {
|
($range_type:ty) => {
|
||||||
impl From<$range_type> for TensorIndexer {
|
impl From<$range_type> for TensorIndexer {
|
||||||
|
Reference in New Issue
Block a user