mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Proper stride initialization.
This commit is contained in:
13
src/shape.rs
13
src/shape.rs
@ -132,4 +132,17 @@ impl Shape {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The strides given in number of elements for a contiguous n-dimensional
|
||||||
|
/// arrays using this shape.
|
||||||
|
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
||||||
|
self.0
|
||||||
|
.iter()
|
||||||
|
.scan(1, |prod, u| {
|
||||||
|
let prod_pre_mult = *prod;
|
||||||
|
*prod *= u;
|
||||||
|
Some(prod_pre_mult)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ use std::sync::Arc;
|
|||||||
pub(crate) struct Tensor_ {
|
pub(crate) struct Tensor_ {
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
shape: Shape,
|
shape: Shape,
|
||||||
|
// The strides are given in number of elements and not in bytes.
|
||||||
stride: Vec<usize>,
|
stride: Vec<usize>,
|
||||||
op: Option<Op>,
|
op: Option<Op>,
|
||||||
}
|
}
|
||||||
@ -21,11 +22,11 @@ impl Tensor {
|
|||||||
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let storage = device.zeros(&shape, dtype);
|
let storage = device.zeros(&shape, dtype);
|
||||||
let rank = shape.rank();
|
let stride = shape.stride_contiguous();
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
storage,
|
storage,
|
||||||
shape,
|
shape,
|
||||||
stride: vec![1; rank],
|
stride,
|
||||||
op: None,
|
op: None,
|
||||||
};
|
};
|
||||||
Self(Arc::new(tensor_))
|
Self(Arc::new(tensor_))
|
||||||
@ -34,11 +35,11 @@ impl Tensor {
|
|||||||
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
pub fn new<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
|
||||||
let shape = array.shape()?;
|
let shape = array.shape()?;
|
||||||
let storage = device.tensor(array);
|
let storage = device.tensor(array);
|
||||||
let rank = shape.rank();
|
let stride = shape.stride_contiguous();
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
storage,
|
storage,
|
||||||
shape,
|
shape,
|
||||||
stride: vec![1; rank],
|
stride,
|
||||||
op: None,
|
op: None,
|
||||||
};
|
};
|
||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
|
Reference in New Issue
Block a user