From 8aab78738496e6836a93a4ff1d0346af9c12898f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 12 Jul 2023 15:42:36 +0100 Subject: [PATCH] Test the index op + bugfix. (#148) --- candle-core/src/indexer.rs | 18 +------ candle-core/tests/indexing_tests.rs | 82 +++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 16 deletions(-) create mode 100644 candle-core/tests/indexing_tests.rs diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index 0651b791..c7a9a140 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -30,7 +30,7 @@ impl Tensor { let mut current_dim = 0; for (i, indexer) in indexers.iter().enumerate() { x = match indexer { - TensorIndexer::Select(n) => x.narrow(i, *n, 1)?.squeeze(i)?, + TensorIndexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?, TensorIndexer::Narrow(left_bound, right_bound) => { let start = match left_bound { Bound::Included(n) => *n, @@ -55,10 +55,10 @@ impl Tensor { #[derive(Debug, Clone)] /// Generic structure used to index a slice of the tensor pub enum TensorIndexer { + /// This selects the elemnts for which an index has some specific value. Select(usize), /// This is a regular slice, purely indexing a chunk of the tensor Narrow(Bound, Bound), - // IndexSelect(Tensor), } impl From for TensorIndexer { @@ -67,20 +67,6 @@ impl From for TensorIndexer { } } -// impl From<&[usize]> for TensorIndexer { -// fn from(index: &[usize]) -> Self { -// let tensor = index.into(); -// TensorIndexer::IndexSelect(tensor) -// } -// } -// -// impl From> for TensorIndexer { -// fn from(index: Vec) -> Self { -// let tensor = Tensor::of_slice(&index); -// TensorIndexer::IndexSelect(tensor) -// } -// } - macro_rules! impl_from_range { ($range_type:ty) => { impl From<$range_type> for TensorIndexer { diff --git a/candle-core/tests/indexing_tests.rs b/candle-core/tests/indexing_tests.rs new file mode 100644 index 00000000..6a71b8fb --- /dev/null +++ b/candle-core/tests/indexing_tests.rs @@ -0,0 +1,82 @@ +use anyhow::Result; +use candle::{Device, IndexOp, Tensor}; + +mod test_utils; + +#[test] +fn integer_index() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((2, 3))?; + let result = tensor.i(1)?; + assert_eq!(result.dims(), &[3]); + assert_eq!(result.to_vec1::()?, &[3, 4, 5]); + + let result = tensor.i((.., 2))?; + assert_eq!(result.dims(), &[2]); + assert_eq!(result.to_vec1::()?, &[2, 5]); + + Ok(()) +} + +#[test] +fn range_index() -> Result<()> { + let dev = Device::Cpu; + // RangeFull + let tensor = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((2, 3))?; + let result = tensor.i(..)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[0, 1, 2], [3, 4, 5]]); + + // Range + let tensor = Tensor::arange(0u32, 4 * 3, &dev)?.reshape((4, 3))?; + let result = tensor.i(1..3)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[3, 4, 5], [6, 7, 8]]); + + // RangeFrom + let result = tensor.i(2..)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[6, 7, 8], [9, 10, 11]]); + + // RangeTo + let result = tensor.i(..2)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[0, 1, 2], [3, 4, 5]]); + + // RangeInclusive + let result = tensor.i(1..=2)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[3, 4, 5], [6, 7, 8]]); + + // RangeTo + let result = tensor.i(..1)?; + assert_eq!(result.dims(), &[1, 3]); + assert_eq!(result.to_vec2::()?, &[[0, 1, 2]]); + + // RangeToInclusive + let result = tensor.i(..=1)?; + assert_eq!(result.dims(), &[2, 3]); + assert_eq!(result.to_vec2::()?, &[[0, 1, 2], [3, 4, 5]]); + Ok(()) +} + +#[test] +fn index_3d() -> Result<()> { + let tensor = Tensor::from_iter(0..24u32, &Device::Cpu)?.reshape((2, 3, 4))?; + assert_eq!(tensor.i((0, 0, 0))?.to_scalar::()?, 0); + assert_eq!(tensor.i((1, 0, 0))?.to_scalar::()?, 12); + assert_eq!(tensor.i((0, 1, 0))?.to_scalar::()?, 4); + assert_eq!(tensor.i((0, 1, 3))?.to_scalar::()?, 7); + assert_eq!(tensor.i((0..2, 0, 0))?.to_vec1::()?, &[0, 12]); + assert_eq!( + tensor.i((0..2, .., 0))?.to_vec2::()?, + &[[0, 4, 8], [12, 16, 20]] + ); + assert_eq!( + tensor.i((..2, .., 3))?.to_vec2::()?, + &[[3, 7, 11], [15, 19, 23]] + ); + assert_eq!(tensor.i((1, .., 3))?.to_vec1::()?, &[15, 19, 23]); + Ok(()) +}