diff --git a/src/lib.rs b/src/lib.rs index 56d13517..58c2ba52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ mod error; mod op; mod shape; mod storage; +mod strided_index; mod tensor; pub use device::Device; @@ -11,4 +12,5 @@ pub use dtype::{DType, WithDType}; pub use error::{Error, Result}; pub use shape::Shape; pub use storage::{CpuStorage, Storage}; +use strided_index::StridedIndex; pub use tensor::{Tensor, TensorId}; diff --git a/src/storage.rs b/src/storage.rs index 65f7c549..463788d4 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,4 +1,4 @@ -use crate::{DType, Device, Error, Result, Shape}; +use crate::{DType, Device, Error, Result, Shape, StridedIndex}; // TODO: Think about whether we would be better off with a dtype and // a buffer as an owned slice of bytes. @@ -17,66 +17,6 @@ impl CpuStorage { } } -#[derive(Debug)] -pub(crate) struct StridedIndex<'a> { - next_storage_index: Option, - multi_index: Vec, - dims: &'a [usize], - stride: &'a [usize], -} - -impl<'a> StridedIndex<'a> { - pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self { - let elem_count: usize = dims.iter().product(); - let next_storage_index = if elem_count == 0 { - None - } else { - // This applies to the scalar case. - Some(0) - }; - StridedIndex { - next_storage_index, - multi_index: vec![0; dims.len()], - dims, - stride, - } - } -} - -impl<'a> Iterator for StridedIndex<'a> { - type Item = usize; - - fn next(&mut self) -> Option { - let storage_index = match self.next_storage_index { - None => return None, - Some(storage_index) => storage_index, - }; - let mut updated = false; - for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { - let next_i = *multi_i + 1; - if next_i < *max_i { - *multi_i = next_i; - updated = true; - break; - } else { - *multi_i = 0 - } - } - self.next_storage_index = if updated { - let next_storage_index = self - .multi_index - .iter() - .zip(self.stride.iter()) - .map(|(&x, &y)| x * y) - .sum(); - Some(next_storage_index) - } else { - None - }; - Some(storage_index) - } -} - #[derive(Debug, Clone)] pub enum Storage { Cpu(CpuStorage), diff --git a/src/strided_index.rs b/src/strided_index.rs new file mode 100644 index 00000000..2a23e9ec --- /dev/null +++ b/src/strided_index.rs @@ -0,0 +1,61 @@ +/// An iterator over offset position for items of an N-dimensional arrays stored in a +/// flat buffer using some potential strides. +#[derive(Debug)] +pub(crate) struct StridedIndex<'a> { + next_storage_index: Option, + multi_index: Vec, + dims: &'a [usize], + stride: &'a [usize], +} + +impl<'a> StridedIndex<'a> { + pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self { + let elem_count: usize = dims.iter().product(); + let next_storage_index = if elem_count == 0 { + None + } else { + // This applies to the scalar case. + Some(0) + }; + StridedIndex { + next_storage_index, + multi_index: vec![0; dims.len()], + dims, + stride, + } + } +} + +impl<'a> Iterator for StridedIndex<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + let storage_index = match self.next_storage_index { + None => return None, + Some(storage_index) => storage_index, + }; + let mut updated = false; + for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { + let next_i = *multi_i + 1; + if next_i < *max_i { + *multi_i = next_i; + updated = true; + break; + } else { + *multi_i = 0 + } + } + self.next_storage_index = if updated { + let next_storage_index = self + .multi_index + .iter() + .zip(self.stride.iter()) + .map(|(&x, &y)| x * y) + .sum(); + Some(next_storage_index) + } else { + None + }; + Some(storage_index) + } +} diff --git a/src/tensor.rs b/src/tensor.rs index bfe01adf..ff6cb3dc 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -232,8 +232,8 @@ impl Tensor { Ok(Self(Arc::new(tensor_))) } - pub(crate) fn strided_index(&self) -> crate::storage::StridedIndex { - crate::storage::StridedIndex::new(self.dims(), self.stride()) + pub(crate) fn strided_index(&self) -> crate::StridedIndex { + crate::StridedIndex::new(self.dims(), self.stride()) } pub fn to_vec1(&self) -> Result> {