mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
83 lines
2.2 KiB
Rust
83 lines
2.2 KiB
Rust
use crate::Layout;
|
|
|
|
/// Iterator yielding the flat-buffer offset of every element of an N-dimensional array,
/// where the array is described by its dimensions and a per-dimension stride.
#[derive(Debug)]
pub struct StridedIndex<'a> {
    /// Offset that the next call to `next` will yield; `None` once there is nothing left.
    next_storage_index: Option<usize>,
    /// Current position along each dimension (one entry per entry of `dims`).
    multi_index: Vec<usize>,
    /// Extent of each dimension.
    dims: &'a [usize],
    /// Distance in the flat buffer between consecutive positions along each dimension.
    stride: &'a [usize],
}
|
|
|
|
impl<'a> StridedIndex<'a> {
|
|
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
|
|
let elem_count: usize = dims.iter().product();
|
|
let next_storage_index = if elem_count == 0 {
|
|
None
|
|
} else {
|
|
// This applies to the scalar case.
|
|
Some(start_offset)
|
|
};
|
|
StridedIndex {
|
|
next_storage_index,
|
|
multi_index: vec![0; dims.len()],
|
|
dims,
|
|
stride,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn from_layout(l: &'a Layout) -> Self {
|
|
Self::new(l.dims(), l.stride(), l.start_offset())
|
|
}
|
|
}
|
|
|
|
impl<'a> Iterator for StridedIndex<'a> {
|
|
type Item = usize;
|
|
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
let storage_index = match self.next_storage_index {
|
|
None => return None,
|
|
Some(storage_index) => storage_index,
|
|
};
|
|
let mut updated = false;
|
|
let mut next_storage_index = storage_index;
|
|
for ((multi_i, max_i), stride_i) in self
|
|
.multi_index
|
|
.iter_mut()
|
|
.zip(self.dims.iter())
|
|
.zip(self.stride.iter())
|
|
.rev()
|
|
{
|
|
let next_i = *multi_i + 1;
|
|
if next_i < *max_i {
|
|
*multi_i = next_i;
|
|
updated = true;
|
|
next_storage_index += stride_i;
|
|
break;
|
|
} else {
|
|
next_storage_index -= *multi_i * stride_i;
|
|
*multi_i = 0
|
|
}
|
|
}
|
|
self.next_storage_index = if updated {
|
|
Some(next_storage_index)
|
|
} else {
|
|
None
|
|
};
|
|
Some(storage_index)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum StridedBlocks<'a> {
|
|
SingleBlock {
|
|
start_offset: usize,
|
|
len: usize,
|
|
},
|
|
MultipleBlocks {
|
|
block_start_index: StridedIndex<'a>,
|
|
block_len: usize,
|
|
},
|
|
}
|