From 98b423145a35306a3d3152ae042ecdbf597a4258 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 20 Jun 2023 13:35:07 +0100 Subject: [PATCH] Bugfix for the contiguous strides. --- src/device.rs | 20 ++++++++++++++++++++ src/dtype.rs | 14 +++++++++----- src/shape.rs | 24 ++++++++++++++++++++++-- src/storage.rs | 1 + tests/tensor_tests.rs | 13 ++++++++++++- 5 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/device.rs b/src/device.rs index 6741e582..0964e83f 100644 --- a/src/device.rs +++ b/src/device.rs @@ -25,6 +25,16 @@ impl NdArray for S { } } +impl NdArray for &[S; N] { + fn shape(&self) -> Result { + Ok(Shape::from(self.len())) + } + + fn to_cpu_storage(&self) -> CpuStorage { + S::to_cpu_storage(self.as_slice()) + } +} + impl NdArray for &[S] { fn shape(&self) -> Result { Ok(Shape::from(self.len())) @@ -35,6 +45,16 @@ impl NdArray for &[S] { } } +impl NdArray for &[[S; N]; M] { + fn shape(&self) -> Result { + Ok(Shape::from((M, N))) + } + + fn to_cpu_storage(&self) -> CpuStorage { + S::to_cpu_storage_owned(self.concat()) + } +} + impl Device { pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage { match self { diff --git a/src/dtype.rs b/src/dtype.rs index a2c92aa7..d66d046c 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -18,7 +18,11 @@ impl DType { pub trait WithDType: Sized + Copy { const DTYPE: DType; - fn to_cpu_storage(data: &[Self]) -> CpuStorage; + fn to_cpu_storage_owned(data: Vec) -> CpuStorage; + + fn to_cpu_storage(data: &[Self]) -> CpuStorage { + Self::to_cpu_storage_owned(data.to_vec()) + } fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; } @@ -26,8 +30,8 @@ pub trait WithDType: Sized + Copy { impl WithDType for f32 { const DTYPE: DType = DType::F32; - fn to_cpu_storage(data: &[Self]) -> CpuStorage { - CpuStorage::F32(data.to_vec()) + fn to_cpu_storage_owned(data: Vec) -> CpuStorage { + CpuStorage::F32(data) } fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { @@ -44,8 +48,8 @@ impl WithDType for f32 { impl WithDType for f64 { const DTYPE: DType = DType::F64; - fn to_cpu_storage(data: &[Self]) -> CpuStorage { - CpuStorage::F64(data.to_vec()) + fn to_cpu_storage_owned(data: Vec) -> CpuStorage { + CpuStorage::F64(data) } fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { diff --git a/src/shape.rs b/src/shape.rs index 3a23442d..97f0f567 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -109,7 +109,8 @@ impl Shape { /// The strides given in number of elements for a contiguous n-dimensional /// arrays using this shape. pub(crate) fn stride_contiguous(&self) -> Vec { - self.0 + let mut stride: Vec<_> = self + .0 .iter() .rev() .scan(1, |prod, u| { @@ -117,6 +118,25 @@ impl Shape { *prod *= u; Some(prod_pre_mult) }) - .collect() + .collect(); + stride.reverse(); + stride + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn stride() { + let shape = Shape::from(()); + assert_eq!(shape.stride_contiguous(), []); + let shape = Shape::from(42); + assert_eq!(shape.stride_contiguous(), [1]); + let shape = Shape::from((42, 1337)); + assert_eq!(shape.stride_contiguous(), [1337, 1]); + let shape = Shape::from((299, 792, 458)); + assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); } } diff --git a/src/storage.rs b/src/storage.rs index c2f47bea..013397d5 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -17,6 +17,7 @@ impl CpuStorage { } } +#[derive(Debug)] pub(crate) struct StridedIndex<'a> { next_storage_index: Option, multi_index: Vec, diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 98b212b3..bea32336 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -11,7 +11,7 @@ fn zeros() -> Result<()> { #[test] fn add_mul() -> Result<()> { - let tensor = Tensor::new([3f32, 1., 4.].as_slice(), Device::Cpu)?; + let tensor = Tensor::new(&[3f32, 1., 4.], Device::Cpu)?; let dim1 = tensor.shape().r1()?; assert_eq!(dim1, 3); let content: Vec = tensor.to_vec1()?; @@ -24,3 +24,14 @@ fn add_mul() -> Result<()> { assert_eq!(content, [36., 4., 64.]); Ok(()) } + +#[test] +fn tensor_2d() -> Result<()> { + let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]]; + let tensor = Tensor::new(data, Device::Cpu)?; + let dims = tensor.shape().r2()?; + assert_eq!(dims, (2, 5)); + let content: Vec> = tensor.to_vec2()?; + assert_eq!(content, data); + Ok(()) +}