Bugfix for the contiguous strides.

laurent
2023-06-20 13:35:07 +01:00
parent d922ff97f2
commit 98b423145a
5 changed files with 64 additions and 8 deletions
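
The bug being fixed: Shape::stride_contiguous builds the strides by scanning over the reversed dims, so the partial products come out innermost-first, and the old code returned them in that reversed order. The patch collects into a mutable vector and reverses it back. A minimal standalone sketch of the corrected logic (a free function over a plain dims slice is hypothetical here; the real method lives on Shape):

// Contiguous (row-major) strides, in number of elements.
fn stride_contiguous(dims: &[usize]) -> Vec<usize> {
    let mut stride: Vec<usize> = dims
        .iter()
        .rev()
        .scan(1, |prod, u| {
            let prod_pre_mult = *prod;
            *prod *= u;
            Some(prod_pre_mult)
        })
        .collect();
    // The scan ran over reversed dims, so the strides are currently
    // innermost-first, e.g. [1, 3] for a (2, 3) shape. Reversing
    // restores the expected outermost-first order [3, 1]; returning
    // without this step was the bug.
    stride.reverse();
    stride
}

fn main() {
    assert_eq!(stride_contiguous(&[2, 3]), [3, 1]);
    assert_eq!(stride_contiguous(&[299, 792, 458]), [458 * 792, 458, 1]);
}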

View File

@@ -25,6 +25,16 @@ impl<S: crate::WithDType> NdArray for S {
     }
 }
 
+impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from(self.len()))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        S::to_cpu_storage(self.as_slice())
+    }
+}
+
 impl<S: crate::WithDType> NdArray for &[S] {
     fn shape(&self) -> Result<Shape> {
         Ok(Shape::from(self.len()))
@@ -35,6 +45,16 @@ impl<S: crate::WithDType> NdArray for &[S] {
     }
 }
 
+impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
+    fn shape(&self) -> Result<Shape> {
+        Ok(Shape::from((M, N)))
+    }
+
+    fn to_cpu_storage(&self) -> CpuStorage {
+        S::to_cpu_storage_owned(self.concat())
+    }
+}
+
 impl Device {
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
         match self {

View File

@@ -18,7 +18,11 @@ impl DType {
 pub trait WithDType: Sized + Copy {
     const DTYPE: DType;
 
-    fn to_cpu_storage(data: &[Self]) -> CpuStorage;
+    fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
+
+    fn to_cpu_storage(data: &[Self]) -> CpuStorage {
+        Self::to_cpu_storage_owned(data.to_vec())
+    }
 
     fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
 }
@@ -26,8 +30,8 @@ pub trait WithDType: Sized + Copy {
 impl WithDType for f32 {
     const DTYPE: DType = DType::F32;
 
-    fn to_cpu_storage(data: &[Self]) -> CpuStorage {
-        CpuStorage::F32(data.to_vec())
+    fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
+        CpuStorage::F32(data)
     }
 
     fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
@@ -44,8 +48,8 @@ impl WithDType for f32 {
 impl WithDType for f64 {
     const DTYPE: DType = DType::F64;
 
-    fn to_cpu_storage(data: &[Self]) -> CpuStorage {
-        CpuStorage::F64(data.to_vec())
+    fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
+        CpuStorage::F64(data)
     }
 
     fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
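
The trait change above swaps which method is required and which is provided: implementors now supply to_cpu_storage_owned taking a Vec, and to_cpu_storage comes for free via data.to_vec(). Callers that already hold an owned Vec (such as the new &[[S; N]; M] impl, whose self.concat() produces one) can then move it into storage without an extra copy. A self-contained sketch of the pattern, with a stub CpuStorage reduced to its F32 variant and the trait's other items (DTYPE, cpu_storage_as_slice) omitted:

#[derive(Debug)]
enum CpuStorage {
    F32(Vec<f32>),
}

trait WithDType: Sized + Copy {
    // Required: the owned version, which can take ownership of a Vec.
    fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;

    // Provided: the borrowed version falls back to a copy.
    fn to_cpu_storage(data: &[Self]) -> CpuStorage {
        Self::to_cpu_storage_owned(data.to_vec())
    }
}

impl WithDType for f32 {
    fn to_cpu_storage_owned(data: Vec<f32>) -> CpuStorage {
        CpuStorage::F32(data) // the Vec is moved in, no copy
    }
}

fn main() {
    let rows = &[[3f32, 1., 4.], [1., 5., 9.]];
    // concat() flattens the rows into an owned Vec<f32>, which is then
    // moved straight into the storage enum.
    let storage = f32::to_cpu_storage_owned(rows.concat());
    println!("{storage:?}");
    // The borrowed path still works, at the cost of a to_vec() copy.
    let storage = f32::to_cpu_storage(&[2.7f32, 1.8]);
    println!("{storage:?}");
}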

View File

@ -109,7 +109,8 @@ impl Shape {
/// The strides given in number of elements for a contiguous n-dimensional /// The strides given in number of elements for a contiguous n-dimensional
/// arrays using this shape. /// arrays using this shape.
pub(crate) fn stride_contiguous(&self) -> Vec<usize> { pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
self.0 let mut stride: Vec<_> = self
.0
.iter() .iter()
.rev() .rev()
.scan(1, |prod, u| { .scan(1, |prod, u| {
@ -117,6 +118,25 @@ impl Shape {
*prod *= u; *prod *= u;
Some(prod_pre_mult) Some(prod_pre_mult)
}) })
.collect() .collect();
stride.reverse();
stride
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn stride() {
let shape = Shape::from(());
assert_eq!(shape.stride_contiguous(), []);
let shape = Shape::from(42);
assert_eq!(shape.stride_contiguous(), [1]);
let shape = Shape::from((42, 1337));
assert_eq!(shape.stride_contiguous(), [1337, 1]);
let shape = Shape::from((299, 792, 458));
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
} }
} }

View File

@@ -17,6 +17,7 @@ impl CpuStorage {
     }
 }
 
+#[derive(Debug)]
 pub(crate) struct StridedIndex<'a> {
     next_storage_index: Option<usize>,
     multi_index: Vec<usize>,
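
The derive above only makes StridedIndex printable for debugging; the diff shows just two of its fields. For context, a hypothetical minimal version of such an iterator (the dims/stride fields and the whole Iterator impl below are assumptions for illustration, not code from this commit) walks a tensor by keeping a multi-dimensional index and mapping it to a storage offset through the strides:

#[derive(Debug)]
struct StridedIndex<'a> {
    next_storage_index: Option<usize>,
    multi_index: Vec<usize>,
    dims: &'a [usize],   // assumed field
    stride: &'a [usize], // assumed field
}

impl<'a> Iterator for StridedIndex<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let storage_index = self.next_storage_index?;
        // Advance the multi-index like an odometer, innermost dim first.
        let mut updated = false;
        for (i, v) in self.multi_index.iter_mut().enumerate().rev() {
            if *v + 1 < self.dims[i] {
                *v += 1;
                updated = true;
                break;
            }
            *v = 0; // carry into the next (outer) dimension
        }
        self.next_storage_index = if updated {
            // Storage offset = dot product of multi-index and strides.
            Some(
                self.multi_index
                    .iter()
                    .zip(self.stride.iter())
                    .map(|(&i, &s)| i * s)
                    .sum(),
            )
        } else {
            None
        };
        Some(storage_index)
    }
}

fn main() {
    let (dims, stride) = ([2usize, 3], [3usize, 1]); // contiguous strides for (2, 3)
    let index = StridedIndex {
        next_storage_index: Some(0),
        multi_index: vec![0; dims.len()],
        dims: &dims,
        stride: &stride,
    };
    // With correct contiguous strides, a contiguous tensor is visited
    // in plain storage order.
    assert_eq!(index.collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);
}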

View File

@@ -11,7 +11,7 @@ fn zeros() -> Result<()> {
 
 #[test]
 fn add_mul() -> Result<()> {
-    let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?;
+    let tensor = Tensor::new(&[3f32, 1., 4.], Device::Cpu)?;
     let dim1 = tensor.shape().r1()?;
     assert_eq!(dim1, 3);
     let content: Vec<f32> = tensor.to_vec1()?;
@@ -24,3 +24,14 @@ fn add_mul() -> Result<()> {
     assert_eq!(content, [36., 4., 64.]);
     Ok(())
 }
+
+#[test]
+fn tensor_2d() -> Result<()> {
+    let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
+    let tensor = Tensor::new(data, Device::Cpu)?;
+    let dims = tensor.shape().r2()?;
+    assert_eq!(dims, (2, 5));
+    let content: Vec<Vec<f32>> = tensor.to_vec2()?;
+    assert_eq!(content, data);
+    Ok(())
+}