Proper stride initialization.

This commit is contained in:
laurent
2023-06-20 07:53:53 +01:00
parent bcae61b7f2
commit ef6760117f
2 changed files with 18 additions and 4 deletions

View File

@ -132,4 +132,17 @@ impl Shape {
}) })
} }
} }
/// The strides given in number of elements for a contiguous n-dimensional
/// arrays using this shape.
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
self.0
.iter()
.scan(1, |prod, u| {
let prod_pre_mult = *prod;
*prod *= u;
Some(prod_pre_mult)
})
.collect()
}
} }

View File

@ -5,6 +5,7 @@ use std::sync::Arc;
pub(crate) struct Tensor_ { pub(crate) struct Tensor_ {
storage: Storage, storage: Storage,
shape: Shape, shape: Shape,
// The strides are given in number of elements and not in bytes.
stride: Vec<usize>, stride: Vec<usize>,
op: Option<Op>, op: Option<Op>,
} }
@ -21,11 +22,11 @@ impl Tensor {
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self { pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
let shape = shape.into(); let shape = shape.into();
let storage = device.zeros(&shape, dtype); let storage = device.zeros(&shape, dtype);
let rank = shape.rank(); let stride = shape.stride_contiguous();
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
storage, storage,
shape, shape,
stride: vec![1; rank], stride,
op: None, op: None,
}; };
Self(Arc::new(tensor_)) Self(Arc::new(tensor_))
@ -34,11 +35,11 @@ impl Tensor {
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> { pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
let shape = array.shape()?; let shape = array.shape()?;
let storage = device.tensor(array); let storage = device.tensor(array);
let rank = shape.rank(); let stride = shape.stride_contiguous();
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
storage, storage,
shape, shape,
stride: vec![1; rank], stride,
op: None, op: None,
}; };
Ok(Self(Arc::new(tensor_))) Ok(Self(Arc::new(tensor_)))