From ef0375d8bcfcf2ec335a12a04d793d784dfca45d Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 10 Jul 2023 17:34:04 +0200
Subject: [PATCH] `i(..)` indexing sugar (partial).

- Only range and select (no tensor_select)
- No negative indexing
---
 candle-core/src/indexer.rs | 244 +++++++++++++++++++++++++++++++++++++
 candle-core/src/lib.rs     |   2 +
 candle-core/src/tensor.rs  |   4 +-
 3 files changed, 249 insertions(+), 1 deletion(-)
 create mode 100644 candle-core/src/indexer.rs

diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
new file mode 100644
index 00000000..44c38206
--- /dev/null
+++ b/candle-core/src/indexer.rs
@@ -0,0 +1,244 @@
+use crate::{Error, Tensor};
+use std::ops::{
+    Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
+};
+
+impl Tensor {
+    /// Intended to be used by the trait method `.i()`
+    ///
+    /// ```
+    /// # use candle::{Tensor, DType, Device, IndexOp};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = a.i(0..1)?;
+    /// assert_eq!(c.shape().dims(), &[1, 3]);
+    ///
+    /// let c = a.i(0)?;
+    /// assert_eq!(c.shape().dims(), &[3]);
+    ///
+    /// let c = a.i((.., ..2))?;
+    /// assert_eq!(c.shape().dims(), &[2, 2]);
+    ///
+    /// let c = a.i((.., ..=2))?;
+    /// assert_eq!(c.shape().dims(), &[2, 3]);
+    ///
+    /// # Ok::<(), candle::Error>(())
+    /// ```
+    fn index(&self, indexers: &[TensorIndexer]) -> Result<Tensor, Error> {
+        let mut x = self.clone();
+        let dims = self.shape().dims();
+        let mut current_dim = 0;
+        for (i, indexer) in indexers.iter().enumerate() {
+            x = match indexer {
+                TensorIndexer::Select(n) => x.get(*n)?,
+                TensorIndexer::Narrow(left_bound, right_bound) => {
+                    let start = match left_bound {
+                        Bound::Included(n) => *n,
+                        Bound::Excluded(n) => *n + 1,
+                        Bound::Unbounded => 0,
+                    };
+                    let stop = match right_bound {
+                        Bound::Included(n) => *n + 1,
+                        Bound::Excluded(n) => *n,
+                        Bound::Unbounded => dims[i],
+                    };
+                    let len = stop - start;
+                    println!(" indexer {indexer:?} Start {start} stop {stop} - {len:?}");
+                    let out = x.narrow(current_dim, start, len)?;
+                    current_dim += 1;
+                    out
+                }
+            };
+        }
+        Ok(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+/// Generic type used to index a slice of the tensor
+pub enum TensorIndexer {
+    Select(usize),
+    /// This is a regular slice, purely indexing a chunk of the tensor
+    Narrow(Bound<usize>, Bound<usize>),
+    // IndexSelect(Tensor),
+}
+
+impl From<usize> for TensorIndexer {
+    fn from(index: usize) -> Self {
+        TensorIndexer::Select(index)
+    }
+}
+
+// impl From<&[usize]> for TensorIndexer {
+//     fn from(index: &[usize]) -> Self {
+//         let tensor = index.into();
+//         TensorIndexer::IndexSelect(tensor)
+//     }
+// }
+//
+// impl From<Vec<usize>> for TensorIndexer {
+//     fn from(index: Vec<usize>) -> Self {
+//         let tensor = Tensor::of_slice(&index);
+//         TensorIndexer::IndexSelect(tensor)
+//     }
+// }
+
+macro_rules! impl_from_range {
+    ($range_type:ty) => {
+        impl From<$range_type> for TensorIndexer {
+            fn from(range: $range_type) -> Self {
+                use std::ops::Bound::*;
+
+                let start = match range.start_bound() {
+                    Included(idx) => Included(*idx),
+                    Excluded(idx) => Excluded(*idx),
+                    Unbounded => Unbounded,
+                };
+
+                let end = match range.end_bound() {
+                    Included(idx) => Included(*idx),
+                    Excluded(idx) => Excluded(*idx),
+                    Unbounded => Unbounded,
+                };
+
+                TensorIndexer::Narrow(start, end)
+            }
+        }
+    };
+}
+
+impl_from_range!(Range<usize>);
+impl_from_range!(RangeFrom<usize>);
+impl_from_range!(RangeFull);
+impl_from_range!(RangeInclusive<usize>);
+impl_from_range!(RangeTo<usize>);
+impl_from_range!(RangeToInclusive<usize>);
+
+/// Trait used to implement multiple signatures for ease of use when slicing
+/// a tensor.
+pub trait IndexOp<T> {
+    /// Returns a tensor slice obtained by applying the indexers to the
+    /// corresponding dimensions of the tensor.
+    fn i(&self, index: T) -> Result<Tensor, Error>;
+}
+
+impl<T> IndexOp<T> for Tensor
+where
+    T: Into<TensorIndexer>,
+{
+    fn i(&self, index: T) -> Result<Tensor, Error> {
+        self.index(&[index.into()])
+    }
+}
+
+impl<A> IndexOp<(A,)> for Tensor
+where
+    A: Into<TensorIndexer>,
+{
+    fn i(&self, index: (A,)) -> Result<Tensor, Error> {
+        let idx_a = index.0.into();
+        self.index(&[idx_a])
+    }
+}
+
+impl<A, B> IndexOp<(A, B)> for Tensor
+where
+    A: Into<TensorIndexer>,
+    B: Into<TensorIndexer>,
+{
+    fn i(&self, index: (A, B)) -> Result<Tensor, Error> {
+        let idx_a = index.0.into();
+        let idx_b = index.1.into();
+        self.index(&[idx_a, idx_b])
+    }
+}
+
+impl<A, B, C> IndexOp<(A, B, C)> for Tensor
+where
+    A: Into<TensorIndexer>,
+    B: Into<TensorIndexer>,
+    C: Into<TensorIndexer>,
+{
+    fn i(&self, index: (A, B, C)) -> Result<Tensor, Error> {
+        let idx_a = index.0.into();
+        let idx_b = index.1.into();
+        let idx_c = index.2.into();
+        self.index(&[idx_a, idx_b, idx_c])
+    }
+}
+
+impl<A, B, C, D> IndexOp<(A, B, C, D)> for Tensor
+where
+    A: Into<TensorIndexer>,
+    B: Into<TensorIndexer>,
+    C: Into<TensorIndexer>,
+    D: Into<TensorIndexer>,
+{
+    fn i(&self, index: (A, B, C, D)) -> Result<Tensor, Error> {
+        let idx_a = index.0.into();
+        let idx_b = index.1.into();
+        let idx_c = index.2.into();
+        let idx_d = index.3.into();
+        self.index(&[idx_a, idx_b, idx_c, idx_d])
+    }
+}
+
+impl<A, B, C, D, E> IndexOp<(A, B, C, D, E)> for Tensor
+where
+    A: Into<TensorIndexer>,
+    B: Into<TensorIndexer>,
+    C: Into<TensorIndexer>,
+    D: Into<TensorIndexer>,
+    E: Into<TensorIndexer>,
+{
+    fn i(&self, index: (A, B, C, D, E)) -> Result<Tensor, Error> {
+        let idx_a = index.0.into();
+        let idx_b = index.1.into();
+        let idx_c = index.2.into();
+        let idx_d = index.3.into();
+        let idx_e = index.4.into();
+        self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e])
+    }
+}
+
+impl<A, B, C, D, E, F> IndexOp<(A, B, C, D, E, F)> for Tensor
+where
+    A: Into<TensorIndexer>,
+    B: Into<TensorIndexer>,
+    C: Into<TensorIndexer>,
+    D: Into<TensorIndexer>,
+    E: Into<TensorIndexer>,
+    F: Into<TensorIndexer>,
+{
+    fn i(&self, index: (A, B, C, D, E, F)) -> Result<Tensor, Error> {
+        let idx_a = index.0.into();
+        let idx_b = index.1.into();
+        let idx_c = index.2.into();
+        let idx_d = index.3.into();
+        let idx_e = index.4.into();
+        let idx_f = index.5.into();
+        self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f])
+    }
+}
+
+impl<A, B, C, D, E, F, G> IndexOp<(A, B, C, D, E, F, G)> for Tensor
+where
+    A: Into<TensorIndexer>,
+    B: Into<TensorIndexer>,
+    C: Into<TensorIndexer>,
+    D: Into<TensorIndexer>,
+    E: Into<TensorIndexer>,
+    F: Into<TensorIndexer>,
+    G: Into<TensorIndexer>,
+{
+    fn i(&self, index: (A, B, C, D, E, F, G)) -> Result<Tensor, Error> {
+        let idx_a = index.0.into();
+        let idx_b = index.1.into();
+        let idx_c = index.2.into();
+        let idx_d = index.3.into();
+        let idx_e = index.4.into();
+        let idx_f = index.5.into();
+        let idx_g = index.6.into();
+        self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g])
+    }
+}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 8df44e37..149e8ecc 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -43,6 +43,7 @@ pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
 mod error;
+mod indexer;
 mod layout;
 mod npy;
 mod op;
@@ -57,6 +58,7 @@ pub use cpu_backend::CpuStorage;
 pub use device::{Device, DeviceLocation};
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
+pub use indexer::IndexOp;
 pub use layout::Layout;
 pub use shape::{Shape, D};
 pub use storage::Storage;
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 9b0681e0..bcd380aa 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -444,6 +444,7 @@ impl Tensor {
                 len,
             })?
         }
+        println!("Narrow {start:?} - {} - {len} - {dims:?}", dims[dim]);
         if start == 0 && dims[dim] == len {
             Ok(self.clone())
         } else {
@@ -452,10 +453,11 @@
             } else {
                 None
             };
+            let layout = self.layout().narrow(dim, start, len)?;
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage: self.storage.clone(),
-                layout: self.layout().narrow(dim, start, len)?,
+                layout,
                 op,
                 is_variable: false,
            };
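
For readers skimming the patch, below is a minimal standalone usage sketch of the new `IndexOp::i` sugar. It is adapted directly from the doc-test in `indexer.rs` above; the crate and API names (`candle`, `Tensor`, `DType`, `Device`, `IndexOp`) are taken from the patch, and per the commit message only ranges and `usize` selects are supported (no tensor-based or negative indexing).

```rust
use candle::{DType, Device, IndexOp, Tensor};

fn main() -> Result<(), candle::Error> {
    // A 2x3 tensor of zeros on the CPU, as in the doc-test above.
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;

    // Range indexing keeps the dimension: shape [1, 3].
    let c = a.i(0..1)?;
    assert_eq!(c.shape().dims(), &[1, 3]);

    // A plain usize selects and drops the dimension: shape [3].
    let c = a.i(0)?;
    assert_eq!(c.shape().dims(), &[3]);

    // Tuples index successive dimensions; `..` keeps a dimension whole.
    let c = a.i((.., ..2))?;
    assert_eq!(c.shape().dims(), &[2, 2]);

    // Inclusive ranges work too: `..=2` covers indices 0, 1 and 2.
    let c = a.i((.., ..=2))?;
    assert_eq!(c.shape().dims(), &[2, 3]);

    Ok(())
}
```

Tensor-based indexing (`TensorIndexer::IndexSelect`) is left commented out in the patch, so only the forms shown here are available.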