Proper stride initialization.

This commit is contained in:
laurent
2023-06-20 07:53:53 +01:00
parent bcae61b7f2
commit ef6760117f
2 changed files with 18 additions and 4 deletions

View File

@ -5,6 +5,7 @@ use std::sync::Arc;
pub(crate) struct Tensor_ {
storage: Storage,
shape: Shape,
// The strides are given in number of elements and not in bytes.
stride: Vec<usize>,
op: Option<Op>,
}
@ -21,11 +22,11 @@ impl Tensor {
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
let shape = shape.into();
let storage = device.zeros(&shape, dtype);
let rank = shape.rank();
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
storage,
shape,
stride: vec![1; rank],
stride,
op: None,
};
Self(Arc::new(tensor_))
@ -34,11 +35,11 @@ impl Tensor {
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
let shape = array.shape()?;
let storage = device.tensor(array);
let rank = shape.rank();
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
storage,
shape,
stride: vec![1; rank],
stride,
op: None,
};
Ok(Self(Arc::new(tensor_)))