i(..) indexing sugar (partial).

- Only range, and select (no tensor_select)
- No negative indexing
This commit is contained in:
Nicolas Patry
2023-07-10 17:34:04 +02:00
parent 221b1aff65
commit ef0375d8bc
3 changed files with 249 additions and 1 deletions

244
candle-core/src/indexer.rs Normal file
View File

@ -0,0 +1,244 @@
use crate::{Error, Tensor};
use std::ops::{
Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
impl Tensor {
/// Intended to be use by the trait `.i()`
///
/// ```
/// # use candle::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = a.i(0..1)?;
/// assert_eq!(c.shape().dims(), &[1, 3]);
///
/// let c = a.i(0)?;
/// assert_eq!(c.shape().dims(), &[3]);
///
/// let c = a.i((.., ..2) )?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
///
/// let c = a.i((.., ..=2))?;
/// assert_eq!(c.shape().dims(), &[2, 3]);
///
/// # Ok::<(), candle::Error>(())
/// ```
fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
let mut x = self.clone();
let dims = self.shape().dims();
let mut current_dim = 0;
for (i, indexer) in indexers.iter().enumerate() {
x = match indexer {
TensorIndexer::Select(n) => x.get(*n)?,
TensorIndexer::Narrow(left_bound, right_bound) => {
let start = match left_bound {
Bound::Included(n) => *n,
Bound::Excluded(n) => *n + 1,
Bound::Unbounded => 0,
};
let stop = match right_bound {
Bound::Included(n) => *n + 1,
Bound::Excluded(n) => *n,
Bound::Unbounded => dims[i],
};
let len = stop - start;
println!(" indexer {indexer:?} Start {start} stop{stop} - {len:?}");
let out = x.narrow(current_dim, start, stop - start)?;
current_dim += 1;
out
}
};
}
Ok(x)
}
}
#[derive(Debug, Clone)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
// IndexSelect(Tensor),
}
impl From<usize> for TensorIndexer {
fn from(index: usize) -> Self {
TensorIndexer::Select(index)
}
}
// impl From<&[usize]> for TensorIndexer {
// fn from(index: &[usize]) -> Self {
// let tensor = index.into();
// TensorIndexer::IndexSelect(tensor)
// }
// }
//
// impl From<Vec<usize>> for TensorIndexer {
// fn from(index: Vec<usize>) -> Self {
// let tensor = Tensor::of_slice(&index);
// TensorIndexer::IndexSelect(tensor)
// }
// }
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
fn from(range: $range_type) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
};
}
impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<usize>);
impl_from_range!(RangeTo<usize>);
impl_from_range!(RangeToInclusive<usize>);
/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor
pub trait IndexOp<T> {
/// Returns a slicing iterator which are the chunks of data necessary to
/// reconstruct the desired tensor.
fn i(&self, index: T) -> Result<Tensor, Error>;
}
impl<T> IndexOp<T> for Tensor
where
T: Into<TensorIndexer>,
{
fn i(&self, index: T) -> Result<Tensor, Error> {
self.index(&[index.into()])
}
}
impl<A> IndexOp<(A,)> for Tensor
where
A: Into<TensorIndexer>,
{
fn i(&self, index: (A,)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
self.index(&[idx_a])
}
}
impl<A, B> IndexOp<(A, B)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
{
fn i(&self, index: (A, B)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
self.index(&[idx_a, idx_b])
}
}
impl<A, B, C> IndexOp<(A, B, C)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
self.index(&[idx_a, idx_b, idx_c])
}
}
impl<A, B, C, D> IndexOp<(A, B, C, D)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
D: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
self.index(&[idx_a, idx_b, idx_c, idx_d])
}
}
impl<A, B, C, D, E> IndexOp<(A, B, C, D, E)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
D: Into<TensorIndexer>,
E: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e])
}
}
impl<A, B, C, D, E, F> IndexOp<(A, B, C, D, E, F)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
D: Into<TensorIndexer>,
E: Into<TensorIndexer>,
F: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E, F)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
let idx_f = index.5.into();
self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f])
}
}
impl<A, B, C, D, E, F, G> IndexOp<(A, B, C, D, E, F, G)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
D: Into<TensorIndexer>,
E: Into<TensorIndexer>,
F: Into<TensorIndexer>,
G: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E, F, G)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
let idx_f = index.5.into();
let idx_g = index.6.into();
self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g])
}
}

View File

@ -43,6 +43,7 @@ pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod error;
mod indexer;
mod layout;
mod npy;
mod op;
@ -57,6 +58,7 @@ pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use shape::{Shape, D};
pub use storage::Storage;

View File

@ -444,6 +444,7 @@ impl Tensor {
len,
})?
}
println!("Narrow {start:?} - {} - {len} - {dims:?}", dims[dim]);
if start == 0 && dims[dim] == len {
Ok(self.clone())
} else {
@ -452,10 +453,11 @@ impl Tensor {
} else {
None
};
let layout = self.layout().narrow(dim, start, len)?;
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout().narrow(dim, start, len)?,
layout,
op,
is_variable: false,
};