Broadcasting performance optimization (cpu) (#182)

* Avoid recomputing the index from scratch each time.

* More performance optimizations.
This commit is contained in:
Laurent Mazare
2023-07-17 13:41:09 +01:00
committed by GitHub
parent 5b1c0bc9be
commit acb2f90469
4 changed files with 125 additions and 11 deletions

View File

@@ -8,7 +8,6 @@ pub struct StridedIndex<'a> {
multi_index: Vec<usize>,
dims: &'a [usize],
stride: &'a [usize],
start_offset: usize,
}
impl<'a> StridedIndex<'a> {
@@ -25,7 +24,6 @@ impl<'a> StridedIndex<'a> {
multi_index: vec![0; dims.len()],
dims,
stride,
start_offset,
}
}
@@ -43,24 +41,26 @@ impl<'a> Iterator for StridedIndex<'a> {
Some(storage_index) => storage_index,
};
let mut updated = false;
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
let mut next_storage_index = storage_index;
for ((multi_i, max_i), stride_i) in self
.multi_index
.iter_mut()
.zip(self.dims.iter())
.zip(self.stride.iter())
.rev()
{
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
updated = true;
next_storage_index += stride_i;
break;
} else {
next_storage_index -= *multi_i * stride_i;
*multi_i = 0
}
}
self.next_storage_index = if updated {
let next_storage_index = self
.multi_index
.iter()
.zip(self.stride.iter())
.map(|(&x, &y)| x * y)
.sum::<usize>()
+ self.start_offset;
Some(next_storage_index)
} else {
None