diff --git a/src/device.rs b/src/device.rs
index e83c844f..1b56c178 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, DType, Result, Shape, Storage};
+use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 
 /// A `DeviceLocation` represents a physical device whereas multiple `Device`
 /// can live on the same location (typically for cuda devices).
@@ -21,7 +21,7 @@ pub trait NdArray {
     fn to_cpu_storage(&self) -> CpuStorage;
 }
 
-impl<S: crate::WithDType> NdArray for S {
+impl<S: WithDType> NdArray for S {
     fn shape(&self) -> Result<Shape> {
         Ok(Shape::from(()))
     }
@@ -31,7 +31,7 @@ impl<S: crate::WithDType> NdArray for S {
     }
 }
 
-impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
+impl<S: WithDType, const N: usize> NdArray for &[S; N] {
     fn shape(&self) -> Result<Shape> {
         Ok(Shape::from(self.len()))
     }
@@ -41,7 +41,7 @@ impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
     }
 }
 
-impl<S: crate::WithDType> NdArray for &[S] {
+impl<S: WithDType> NdArray for &[S] {
     fn shape(&self) -> Result<Shape> {
         Ok(Shape::from(self.len()))
     }
@@ -51,7 +51,7 @@ impl<S: crate::WithDType> NdArray for &[S] {
     }
 }
 
-impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
+impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
     fn shape(&self) -> Result<Shape> {
         Ok(Shape::from((M, N)))
     }
@@ -61,7 +61,7 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
     }
 }
 
-impl<S: crate::WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
+impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
     for &[[[S; N3]; N2]; N1]
 {
     fn shape(&self) -> Result<Shape> {
@@ -138,4 +138,15 @@ impl Device {
             }
         }
     }
+
+    pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
+        match self {
+            Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
+            Device::Cuda(device) => {
+                let storage = S::to_cpu_storage_owned(data);
+                let storage = device.cuda_from_cpu_storage(&storage)?;
+                Ok(Storage::Cuda(storage))
+            }
+        }
+    }
 }
diff --git a/src/npy.rs b/src/npy.rs
index 81aa565b..43a6cb1c 100644
--- a/src/npy.rs
+++ b/src/npy.rs
@@ -195,20 +195,19 @@ impl Tensor {
         let elem_count = shape.elem_count();
         match dtype {
             DType::F32 => {
-                // TODO: Avoid the data being copied around multiple times.
                 let mut data_t = vec![0f32; elem_count];
                 reader.read_f32_into::<LittleEndian>(&mut data_t)?;
-                Tensor::from_slice(&data_t, shape, &Device::Cpu)
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
             DType::F64 => {
                 let mut data_t = vec![0f64; elem_count];
                 reader.read_f64_into::<LittleEndian>(&mut data_t)?;
-                Tensor::from_slice(&data_t, shape, &Device::Cpu)
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
             DType::U32 => {
                 let mut data_t = vec![0u32; elem_count];
                 reader.read_u32_into::<LittleEndian>(&mut data_t)?;
-                Tensor::from_slice(&data_t, shape, &Device::Cpu)
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
         }
     }
diff --git a/src/tensor.rs b/src/tensor.rs
index 2cf24a06..57695819 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -199,6 +199,37 @@ impl Tensor {
         Self::new_impl(array, shape, device, true)
     }
 
+    pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+        data: Vec<D>,
+        shape: S,
+        device: &Device,
+        is_variable: bool,
+    ) -> Result<Self> {
+        let shape = shape.into();
+        let buffer_size = data.len();
+        if buffer_size != shape.elem_count() {
+            return Err(Error::ShapeMismatch { buffer_size, shape });
+        }
+        let storage = device.storage_owned(data)?;
+        Ok(from_storage(storage, shape, None, is_variable))
+    }
+
+    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
+        data: Vec<D>,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        Self::from_vec_impl(data, shape, device, false)
+    }
+
+    pub fn var_from_vec<S: Into<Shape>, D: crate::WithDType>(
+        data: Vec<D>,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        Self::from_vec_impl(data, shape, device, true)
+    }
+
     pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
         array: &[D],
         shape: S,