Proper stride initialization.

This commit is contained in:
laurent
2023-06-20 07:53:53 +01:00
parent bcae61b7f2
commit ef6760117f
2 changed files with 18 additions and 4 deletions

View File

@ -132,4 +132,17 @@ impl Shape {
})
}
}
/// The strides given in number of elements for a contiguous n-dimensional
/// arrays using this shape.
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
self.0
.iter()
.scan(1, |prod, u| {
let prod_pre_mult = *prod;
*prod *= u;
Some(prod_pre_mult)
})
.collect()
}
}

View File

@ -5,6 +5,7 @@ use std::sync::Arc;
pub(crate) struct Tensor_ {
storage: Storage,
shape: Shape,
// The strides are given in number of elements and not in bytes.
stride: Vec<usize>,
op: Option<Op>,
}
@ -21,11 +22,11 @@ impl Tensor {
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
let shape = shape.into();
let storage = device.zeros(&shape, dtype);
let rank = shape.rank();
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
storage,
shape,
stride: vec![1; rank],
stride,
op: None,
};
Self(Arc::new(tensor_))
@ -34,11 +35,11 @@ impl Tensor {
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
let shape = array.shape()?;
let storage = device.tensor(array);
let rank = shape.rank();
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
storage,
shape,
stride: vec![1; rank],
stride,
op: None,
};
Ok(Self(Arc::new(tensor_)))