mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Iteration over strided blocks (#175)
* Introduce the strided blocks. * Use the strided blocks to speed up the copy. * Add more testing.
This commit is contained in:
@ -3,28 +3,35 @@ use crate::Layout;
|
||||
/// An iterator over offset positions for items of an N-dimensional array stored in a
|
||||
/// flat buffer using some potential strides.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StridedIndex<'a> {
|
||||
pub struct StridedIndex<'a> {
|
||||
next_storage_index: Option<usize>,
|
||||
multi_index: Vec<usize>,
|
||||
layout: &'a Layout,
|
||||
dims: &'a [usize],
|
||||
stride: &'a [usize],
|
||||
start_offset: usize,
|
||||
}
|
||||
|
||||
impl<'a> StridedIndex<'a> {
|
||||
pub(crate) fn new(layout: &'a Layout) -> Self {
|
||||
let dims = layout.dims();
|
||||
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
|
||||
let elem_count: usize = dims.iter().product();
|
||||
let next_storage_index = if elem_count == 0 {
|
||||
None
|
||||
} else {
|
||||
// This applies to the scalar case.
|
||||
Some(layout.start_offset())
|
||||
Some(start_offset)
|
||||
};
|
||||
StridedIndex {
|
||||
next_storage_index,
|
||||
multi_index: vec![0; dims.len()],
|
||||
layout,
|
||||
dims,
|
||||
stride,
|
||||
start_offset,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_layout(l: &'a Layout) -> Self {
|
||||
Self::new(l.dims(), l.stride(), l.start_offset())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for StridedIndex<'a> {
|
||||
@ -36,12 +43,7 @@ impl<'a> Iterator for StridedIndex<'a> {
|
||||
Some(storage_index) => storage_index,
|
||||
};
|
||||
let mut updated = false;
|
||||
for (multi_i, max_i) in self
|
||||
.multi_index
|
||||
.iter_mut()
|
||||
.zip(self.layout.dims().iter())
|
||||
.rev()
|
||||
{
|
||||
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
|
||||
let next_i = *multi_i + 1;
|
||||
if next_i < *max_i {
|
||||
*multi_i = next_i;
|
||||
@ -55,10 +57,10 @@ impl<'a> Iterator for StridedIndex<'a> {
|
||||
let next_storage_index = self
|
||||
.multi_index
|
||||
.iter()
|
||||
.zip(self.layout.stride().iter())
|
||||
.zip(self.stride.iter())
|
||||
.map(|(&x, &y)| x * y)
|
||||
.sum::<usize>()
|
||||
+ self.layout.start_offset();
|
||||
+ self.start_offset;
|
||||
Some(next_storage_index)
|
||||
} else {
|
||||
None
|
||||
@ -66,3 +68,15 @@ impl<'a> Iterator for StridedIndex<'a> {
|
||||
Some(storage_index)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum StridedBlocks<'a> {
|
||||
SingleBlock {
|
||||
start_offset: usize,
|
||||
len: usize,
|
||||
},
|
||||
MultipleBlocks {
|
||||
block_start_index: StridedIndex<'a>,
|
||||
block_len: usize,
|
||||
},
|
||||
}
|
||||
|
Reference in New Issue
Block a user