From ef6760117f62a28c081103417fac7f872656b842 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 20 Jun 2023 07:53:53 +0100 Subject: [PATCH] Proper stride initialization. --- src/shape.rs | 13 +++++++++++++ src/tensor.rs | 9 +++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/shape.rs b/src/shape.rs index b1965872..a5fee614 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -132,4 +132,17 @@ impl Shape { }) } } + + /// The strides given in number of elements for a contiguous n-dimensional + /// arrays using this shape. + pub(crate) fn stride_contiguous(&self) -> Vec { + self.0 + .iter() + .scan(1, |prod, u| { + let prod_pre_mult = *prod; + *prod *= u; + Some(prod_pre_mult) + }) + .collect() + } } diff --git a/src/tensor.rs b/src/tensor.rs index 37f010da..9b1e7d5b 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -5,6 +5,7 @@ use std::sync::Arc; pub(crate) struct Tensor_ { storage: Storage, shape: Shape, + // The strides are given in number of elements and not in bytes. stride: Vec, op: Option, } @@ -21,11 +22,11 @@ impl Tensor { pub fn zeros>(shape: S, dtype: DType, device: Device) -> Self { let shape = shape.into(); let storage = device.zeros(&shape, dtype); - let rank = shape.rank(); + let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { storage, shape, - stride: vec![1; rank], + stride, op: None, }; Self(Arc::new(tensor_)) @@ -34,11 +35,11 @@ impl Tensor { pub fn new(array: A, device: Device) -> Result { let shape = array.shape()?; let storage = device.tensor(array); - let rank = shape.rank(); + let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { storage, shape, - stride: vec![1; rank], + stride, op: None, }; Ok(Self(Arc::new(tensor_)))