mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Bugfix for the contiguous strides.
This commit is contained in:
24
src/shape.rs
24
src/shape.rs
@ -109,7 +109,8 @@ impl Shape {
|
||||
/// The strides given in number of elements for a contiguous n-dimensional
|
||||
/// arrays using this shape.
|
||||
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
||||
self.0
|
||||
let mut stride: Vec<_> = self
|
||||
.0
|
||||
.iter()
|
||||
.rev()
|
||||
.scan(1, |prod, u| {
|
||||
@ -117,6 +118,25 @@ impl Shape {
|
||||
*prod *= u;
|
||||
Some(prod_pre_mult)
|
||||
})
|
||||
.collect()
|
||||
.collect();
|
||||
stride.reverse();
|
||||
stride
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn stride() {
|
||||
let shape = Shape::from(());
|
||||
assert_eq!(shape.stride_contiguous(), []);
|
||||
let shape = Shape::from(42);
|
||||
assert_eq!(shape.stride_contiguous(), [1]);
|
||||
let shape = Shape::from((42, 1337));
|
||||
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
||||
let shape = Shape::from((299, 792, 458));
|
||||
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user