From 2f7a072250625b7311303cbe36bd064685d72d30 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 22 Jun 2023 16:00:22 +0100 Subject: [PATCH] Rename as_slice to storage_data and implement the cuda version. --- src/dtype.rs | 11 +++++++++++ src/tensor.rs | 16 +++++++++++++--- tests/grad_tests.rs | 4 ++-- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/dtype.rs b/src/dtype.rs index f6249ff2..1b348175 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -26,6 +26,7 @@ pub trait WithDType: Sized + Copy { fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>; + fn cpu_storage_data(s: CpuStorage) -> Result>; } macro_rules! with_dtype { @@ -37,6 +38,16 @@ macro_rules! with_dtype { CpuStorage::$dtype(data) } + fn cpu_storage_data(s: CpuStorage) -> Result> { + match s { + CpuStorage::$dtype(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + }), + } + } + fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> { match s { CpuStorage::$dtype(data) => Ok(data), diff --git a/src/tensor.rs b/src/tensor.rs index 07744d70..72b756a0 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -313,10 +313,20 @@ impl Tensor { crate::StridedIndex::new(self.dims(), self.stride()) } - pub fn as_slice(&self) -> Result<&[S]> { + /// Returns data from the underlying storage, this does not take the strides + /// into account so the size of the resulting buffer might be larger than the + /// tensor number of elements. + pub fn storage_data(&self) -> Result> { match &self.storage { - Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage), - Storage::Cuda { .. } => todo!(), + Storage::Cpu(cpu_storage) => { + let slice = S::cpu_storage_as_slice(cpu_storage)?; + Ok(std::borrow::Cow::Borrowed(slice)) + } + Storage::Cuda(slice) => { + let cpu_storage = slice.to_cpu_storage()?; + let storage_data = S::cpu_storage_data(cpu_storage)?; + Ok(std::borrow::Cow::Owned(storage_data)) + } } } diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs index 77a32dfe..612dffee 100644 --- a/tests/grad_tests.rs +++ b/tests/grad_tests.rs @@ -29,11 +29,11 @@ fn matmul_grad() -> Result<()> { assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3))); assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2))); assert_eq!( - grad_x.as_slice::()?, + &*grad_x.storage_data::()?, &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.] ); assert_eq!( - grad_y.as_slice::()?, + &*grad_y.storage_data::()?, &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.] ); Ok(())