Avoid copying the data around when loading weights.

This commit is contained in:
laurent
2023-06-26 08:09:03 +01:00
parent 4ad5d17d8c
commit 512d12e38d
3 changed files with 51 additions and 10 deletions

View File

@ -1,4 +1,4 @@
use crate::{CpuStorage, DType, Result, Shape, Storage};
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// can live on the same location (typically for cuda devices).
@ -21,7 +21,7 @@ pub trait NdArray {
fn to_cpu_storage(&self) -> CpuStorage;
}
impl<S: crate::WithDType> NdArray for S {
impl<S: WithDType> NdArray for S {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(()))
}
@ -31,7 +31,7 @@ impl<S: crate::WithDType> NdArray for S {
}
}
impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
impl<S: WithDType, const N: usize> NdArray for &[S; N] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
@ -41,7 +41,7 @@ impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
}
}
impl<S: crate::WithDType> NdArray for &[S] {
impl<S: WithDType> NdArray for &[S] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
@ -51,7 +51,7 @@ impl<S: crate::WithDType> NdArray for &[S] {
}
}
impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((M, N)))
}
@ -61,7 +61,7 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
}
}
impl<S: crate::WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
for &[[[S; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
@ -138,4 +138,15 @@ impl Device {
}
}
}
/// Builds a `Storage` on this device from an owned vector, moving the data
/// into CPU storage and only performing a copy when a transfer to a CUDA
/// device is required.
pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
    // Both device kinds start from the same owned CPU-side storage, so build
    // it once up front; no copy happens here, the Vec is moved in.
    let cpu_storage = S::to_cpu_storage_owned(data);
    match self {
        Device::Cpu => Ok(Storage::Cpu(cpu_storage)),
        // The only copy is the host-to-device transfer below.
        Device::Cuda(device) => Ok(Storage::Cuda(device.cuda_from_cpu_storage(&cpu_storage)?)),
    }
}
}

View File

@ -195,20 +195,19 @@ impl Tensor {
let elem_count = shape.elem_count();
match dtype {
DType::F32 => {
// TODO: Avoid the data being copied around multiple times.
let mut data_t = vec![0f32; elem_count];
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_slice(&data_t, shape, &Device::Cpu)
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F64 => {
let mut data_t = vec![0f64; elem_count];
reader.read_f64_into::<LittleEndian>(&mut data_t)?;
Tensor::from_slice(&data_t, shape, &Device::Cpu)
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::U32 => {
let mut data_t = vec![0u32; elem_count];
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_slice(&data_t, shape, &Device::Cpu)
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
}
}

View File

@ -199,6 +199,37 @@ impl Tensor {
Self::new_impl(array, shape, device, true)
}
pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
let buffer_size = data.len();
if buffer_size != shape.elem_count() {
return Err(Error::ShapeMismatch { buffer_size, shape });
}
let storage = device.storage_owned(data)?;
Ok(from_storage(storage, shape, None, is_variable))
}
/// Creates a (non-variable) tensor on `device` from an owned vector, avoiding
/// any extra copy of the data. Fails with a shape-mismatch error if `data`
/// does not contain exactly `shape.elem_count()` elements.
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
    data: Vec<D>,
    shape: S,
    device: &Device,
) -> Result<Self> {
    // Delegate to the shared implementation with is_variable = false.
    Self::from_vec_impl(data, shape, device, false)
}
/// Creates a variable tensor on `device` from an owned vector, avoiding any
/// extra copy of the data. Fails with a shape-mismatch error if `data` does
/// not contain exactly `shape.elem_count()` elements.
pub fn var_from_vec<S: Into<Shape>, D: crate::WithDType>(
    data: Vec<D>,
    shape: S,
    device: &Device,
) -> Result<Self> {
    // Delegate to the shared implementation with is_variable = true.
    Self::from_vec_impl(data, shape, device, true)
}
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
array: &[D],
shape: S,