Rename as_slice to storage_data and implement the cuda version.

This commit is contained in:
laurent
2023-06-22 16:00:22 +01:00
parent 065b7a19c7
commit 2f7a072250
3 changed files with 26 additions and 5 deletions

View File

@ -26,6 +26,7 @@ pub trait WithDType: Sized + Copy {
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
}
macro_rules! with_dtype {
@ -37,6 +38,16 @@ macro_rules! with_dtype {
CpuStorage::$dtype(data)
}
/// Consumes `s` and returns its underlying vector when the storage variant
/// matches the dtype this implementation is generated for.
///
/// # Errors
/// Returns `Error::UnexpectedDType` when `s` holds any other variant; the
/// error carries the expected dtype and the dtype reported by `s.dtype()`.
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
match s {
// The matching variant hands its vector over by move.
CpuStorage::$dtype(data) => Ok(data),
// Nothing was moved out of `s` on this path, so it can still be queried.
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
}),
}
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
match s {
CpuStorage::$dtype(data) => Ok(data),

View File

@ -313,10 +313,20 @@ impl Tensor {
crate::StridedIndex::new(self.dims(), self.stride())
}
pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
/// Returns the data held by the underlying storage.
///
/// Strides are not taken into account, so the returned buffer may hold more
/// elements than the tensor itself.
///
/// CPU storage is borrowed (`Cow::Borrowed`); CUDA storage is first converted
/// with `to_cpu_storage` and then returned as an owned vector (`Cow::Owned`).
///
/// # Errors
/// Propagates any error from `S::cpu_storage_as_slice`, `to_cpu_storage`, or
/// `S::cpu_storage_data`.
pub fn storage_data<S: crate::WithDType>(&self) -> Result<std::borrow::Cow<[S]>> {
match &self.storage {
Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
Storage::Cuda { .. } => todo!(),
Storage::Cpu(cpu_storage) => {
let slice = S::cpu_storage_as_slice(cpu_storage)?;
Ok(std::borrow::Cow::Borrowed(slice))
}
Storage::Cuda(slice) => {
// Convert to CPU storage, then take ownership of its vector.
let cpu_storage = slice.to_cpu_storage()?;
let storage_data = S::cpu_storage_data(cpu_storage)?;
Ok(std::borrow::Cow::Owned(storage_data))
}
}
}

View File

@ -29,11 +29,11 @@ fn matmul_grad() -> Result<()> {
assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3)));
assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2)));
assert_eq!(
grad_x.as_slice::<f32>()?,
&*grad_x.storage_data::<f32>()?,
&[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.]
);
assert_eq!(
grad_y.as_slice::<f32>()?,
&*grad_y.storage_data::<f32>()?,
&[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.]
);
Ok(())