Move the StridedIndex in its own module.

This commit is contained in:
laurent
2023-06-21 07:44:36 +01:00
parent 23db8a7da8
commit 3a5405ca6d
4 changed files with 66 additions and 63 deletions

View File

@ -4,6 +4,7 @@ mod error;
mod op;
mod shape;
mod storage;
mod strided_index;
mod tensor;
pub use device::Device;
@ -11,4 +12,5 @@ pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use shape::Shape;
pub use storage::{CpuStorage, Storage};
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};

View File

@ -1,4 +1,4 @@
use crate::{DType, Device, Error, Result, Shape};
use crate::{DType, Device, Error, Result, Shape, StridedIndex};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@ -17,66 +17,6 @@ impl CpuStorage {
}
}
#[derive(Debug)]
pub(crate) struct StridedIndex<'a> {
next_storage_index: Option<usize>,
multi_index: Vec<usize>,
dims: &'a [usize],
stride: &'a [usize],
}
impl<'a> StridedIndex<'a> {
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
let elem_count: usize = dims.iter().product();
let next_storage_index = if elem_count == 0 {
None
} else {
// This applies to the scalar case.
Some(0)
};
StridedIndex {
next_storage_index,
multi_index: vec![0; dims.len()],
dims,
stride,
}
}
}
impl<'a> Iterator for StridedIndex<'a> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
let storage_index = match self.next_storage_index {
None => return None,
Some(storage_index) => storage_index,
};
let mut updated = false;
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
updated = true;
break;
} else {
*multi_i = 0
}
}
self.next_storage_index = if updated {
let next_storage_index = self
.multi_index
.iter()
.zip(self.stride.iter())
.map(|(&x, &y)| x * y)
.sum();
Some(next_storage_index)
} else {
None
};
Some(storage_index)
}
}
#[derive(Debug, Clone)]
pub enum Storage {
Cpu(CpuStorage),

61
src/strided_index.rs Normal file
View File

@ -0,0 +1,61 @@
/// An iterator over offset position for items of an N-dimensional arrays stored in a
/// flat buffer using some potential strides.
#[derive(Debug)]
pub(crate) struct StridedIndex<'a> {
next_storage_index: Option<usize>,
multi_index: Vec<usize>,
dims: &'a [usize],
stride: &'a [usize],
}
impl<'a> StridedIndex<'a> {
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
let elem_count: usize = dims.iter().product();
let next_storage_index = if elem_count == 0 {
None
} else {
// This applies to the scalar case.
Some(0)
};
StridedIndex {
next_storage_index,
multi_index: vec![0; dims.len()],
dims,
stride,
}
}
}
impl<'a> Iterator for StridedIndex<'a> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
let storage_index = match self.next_storage_index {
None => return None,
Some(storage_index) => storage_index,
};
let mut updated = false;
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
updated = true;
break;
} else {
*multi_i = 0
}
}
self.next_storage_index = if updated {
let next_storage_index = self
.multi_index
.iter()
.zip(self.stride.iter())
.map(|(&x, &y)| x * y)
.sum();
Some(next_storage_index)
} else {
None
};
Some(storage_index)
}
}

View File

@ -232,8 +232,8 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
pub(crate) fn strided_index(&self) -> crate::storage::StridedIndex {
crate::storage::StridedIndex::new(self.dims(), self.stride())
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(self.dims(), self.stride())
}
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {