From 6c5fc767a8d0804c9ba8000cbf4877975536cf6a Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 20 Jun 2023 10:50:58 +0100 Subject: [PATCH] Add the slice indexing. --- src/storage.rs | 62 ++++++++++++++++++++++++++++++++++++++++++- src/tensor.rs | 21 ++++++++++++--- tests/tensor_tests.rs | 2 ++ 3 files changed, 81 insertions(+), 4 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index 859b3f76..699cead9 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -17,7 +17,65 @@ impl CpuStorage { } } -#[allow(dead_code)] +pub(crate) struct StridedIndex<'a> { + next_storage_index: Option, + multi_index: Vec, + dims: &'a [usize], + stride: &'a [usize], +} + +impl<'a> StridedIndex<'a> { + pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self { + let elem_count: usize = dims.iter().product(); + let next_storage_index = if elem_count == 0 { + None + } else { + // This applies to the scalar case. + Some(0) + }; + StridedIndex { + next_storage_index, + multi_index: vec![0; dims.len()], + dims, + stride, + } + } +} + +impl<'a> Iterator for StridedIndex<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + let storage_index = match self.next_storage_index { + None => return None, + Some(storage_index) => storage_index, + }; + let mut updated = false; + for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { + let next_i = *multi_i + 1; + if next_i < *max_i { + *multi_i = next_i; + updated = true; + break; + } else { + *multi_i = 0 + } + } + self.next_storage_index = if updated { + let next_storage_index = self + .multi_index + .iter() + .zip(self.stride.iter()) + .map(|(&x, &y)| x * y) + .sum(); + Some(next_storage_index) + } else { + None + }; + Some(storage_index) + } +} + #[derive(Debug, Clone)] pub enum Storage { Cpu(CpuStorage), @@ -56,6 +114,7 @@ impl Storage { } } + // TODO: Support broadcasting? pub(crate) fn add_impl( &self, rhs: &Self, @@ -95,6 +154,7 @@ impl Storage { } } + // TODO: Support broadcasting? pub(crate) fn mul_impl( &self, rhs: &Self, diff --git a/src/tensor.rs b/src/tensor.rs index 8881cad0..97573158 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -114,13 +114,28 @@ impl Tensor { } } + pub(crate) fn strided_index(&self) -> crate::storage::StridedIndex { + crate::storage::StridedIndex::new(self.dims(), self.stride()) + } + pub fn to_vec1(&self) -> Result> { - // TODO: properly use the strides here. - todo!() + if self.rank() != 1 { + return Err(Error::UnexpectedNumberOfDims { + expected: 1, + got: self.rank(), + shape: self.shape().clone(), + }); + } + match &self.storage { + Storage::Cpu(cpu_storage) => { + let data = S::cpu_storage_as_slice(cpu_storage)?; + Ok(self.strided_index().map(|i| data[i]).collect()) + } + } } pub fn to_vec2(&self) -> Result>> { - // TODO: properly use the strides here. + // TODO: Similar to to_vec1 then reshape the resulting vec? todo!() } diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index a0f4630d..5cde2db5 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -9,5 +9,7 @@ fn add() -> Result<()> { let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?; let dim1 = tensor.shape().r1()?; assert_eq!(dim1, 3); + let content: Vec = tensor.to_vec1()?; + assert_eq!(content, [3., 1., 4.]); Ok(()) }