Bugfix for the contiguous strides.

This commit is contained in:
laurent
2023-06-20 13:35:07 +01:00
parent d922ff97f2
commit 98b423145a
5 changed files with 64 additions and 8 deletions

View File

@ -109,7 +109,8 @@ impl Shape {
/// The strides given in number of elements for a contiguous n-dimensional
/// arrays using this shape.
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
self.0
let mut stride: Vec<_> = self
.0
.iter()
.rev()
.scan(1, |prod, u| {
@ -117,6 +118,25 @@ impl Shape {
*prod *= u;
Some(prod_pre_mult)
})
.collect()
.collect();
stride.reverse();
stride
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn stride() {
let shape = Shape::from(());
assert_eq!(shape.stride_contiguous(), []);
let shape = Shape::from(42);
assert_eq!(shape.stride_contiguous(), [1]);
let shape = Shape::from((42, 1337));
assert_eq!(shape.stride_contiguous(), [1337, 1]);
let shape = Shape::from((299, 792, 458));
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}