Iteration over strided blocks (#175)

* Introduce the strided blocks.

* Use the strided blocks to fasten the copy.

* Add more testing.
This commit is contained in:
Laurent Mazare
2023-07-15 21:30:35 +01:00
committed by GitHub
parent ad91415b4f
commit 18ea92d83b
6 changed files with 236 additions and 23 deletions

View File

@ -190,13 +190,17 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
dst_offset: usize,
src_l: &Layout,
) {
match src_l.contiguous_offsets() {
Some((o_dst1, o_dst2)) => {
let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1);
dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2])
match src_l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let to_copy = (dst.len() - dst_offset).min(len);
dst[dst_offset..dst_offset + to_copy]
.copy_from_slice(&src[start_offset..start_offset + to_copy])
}
None => {
for (dst_index, src_index) in src_l.strided_index().enumerate() {
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len: 1,
} => {
for (dst_index, src_index) in block_start_index.enumerate() {
let dst_index = dst_index + dst_offset;
if dst_index >= dst.len() {
break;
@ -204,6 +208,22 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
dst[dst_index] = src[src_index]
}
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut dst_index = dst_offset;
for src_index in block_start_index {
let next_dst_index = dst_index + block_len;
if dst_index >= dst.len() {
break;
}
let to_copy = usize::min(block_len, dst.len() - dst_index);
dst[dst_index..dst_index + to_copy]
.copy_from_slice(&src[src_index..src_index + to_copy]);
dst_index = next_dst_index
}
}
}
}

View File

@ -51,6 +51,8 @@ impl Layout {
}
/// Returns true if the data is stored in a C contiguous (aka row major) way.
/// Note that this does not implies that the start offset is 0 or that there are no extra
/// elements at the end of the storage.
pub fn is_contiguous(&self) -> bool {
self.shape.is_contiguous(&self.stride)
}
@ -146,6 +148,35 @@ impl Layout {
}
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(self)
crate::StridedIndex::from_layout(self)
}
pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
let mut block_len = 1;
let mut contiguous_dims = 0; // These are counted from the right.
for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
if stride != block_len {
break;
}
block_len *= dim;
contiguous_dims += 1;
}
let index_dims = self.dims().len() - contiguous_dims;
if index_dims == 0 {
crate::StridedBlocks::SingleBlock {
start_offset: self.start_offset,
len: block_len,
}
} else {
let block_start_index = crate::StridedIndex::new(
&self.dims()[..index_dims],
&self.stride[..index_dims],
self.start_offset,
);
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
}
}
}
}

View File

@ -67,7 +67,7 @@ pub use indexer::IndexOp;
pub use layout::Layout;
pub use shape::{Shape, D};
pub use storage::Storage;
use strided_index::StridedIndex;
pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use variable::Var;

View File

@ -3,28 +3,35 @@ use crate::Layout;
/// An iterator over offset position for items of an N-dimensional arrays stored in a
/// flat buffer using some potential strides.
#[derive(Debug)]
pub(crate) struct StridedIndex<'a> {
pub struct StridedIndex<'a> {
next_storage_index: Option<usize>,
multi_index: Vec<usize>,
layout: &'a Layout,
dims: &'a [usize],
stride: &'a [usize],
start_offset: usize,
}
impl<'a> StridedIndex<'a> {
pub(crate) fn new(layout: &'a Layout) -> Self {
let dims = layout.dims();
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
let elem_count: usize = dims.iter().product();
let next_storage_index = if elem_count == 0 {
None
} else {
// This applies to the scalar case.
Some(layout.start_offset())
Some(start_offset)
};
StridedIndex {
next_storage_index,
multi_index: vec![0; dims.len()],
layout,
dims,
stride,
start_offset,
}
}
pub(crate) fn from_layout(l: &'a Layout) -> Self {
Self::new(l.dims(), l.stride(), l.start_offset())
}
}
impl<'a> Iterator for StridedIndex<'a> {
@ -36,12 +43,7 @@ impl<'a> Iterator for StridedIndex<'a> {
Some(storage_index) => storage_index,
};
let mut updated = false;
for (multi_i, max_i) in self
.multi_index
.iter_mut()
.zip(self.layout.dims().iter())
.rev()
{
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
@ -55,10 +57,10 @@ impl<'a> Iterator for StridedIndex<'a> {
let next_storage_index = self
.multi_index
.iter()
.zip(self.layout.stride().iter())
.zip(self.stride.iter())
.map(|(&x, &y)| x * y)
.sum::<usize>()
+ self.layout.start_offset();
+ self.start_offset;
Some(next_storage_index)
} else {
None
@ -66,3 +68,15 @@ impl<'a> Iterator for StridedIndex<'a> {
Some(storage_index)
}
}
#[derive(Debug)]
pub enum StridedBlocks<'a> {
SingleBlock {
start_offset: usize,
len: usize,
},
MultipleBlocks {
block_start_index: StridedIndex<'a>,
block_len: usize,
},
}

View File

@ -834,10 +834,20 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
/// Returns an iterator over position of the elements in the storage when ranging over the
/// index tuples in lexicographic order.
pub fn strided_index(&self) -> crate::StridedIndex {
self.layout.strided_index()
}
/// Similar to `strided_index` but returns the position of the start of each contiguous block
/// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator
/// will only return the start offset and the size would be the number of elements in the
/// tensor.
pub fn strided_blocks(&self) -> crate::StridedBlocks {
self.layout.strided_blocks()
}
/// Returns the data contained in a 1D tensor as a vector of scalar values.
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
if self.rank() != 1 {