mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Avoid copying the data around when loading weights.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::{CpuStorage, DType, Result, Shape, Storage};
|
||||
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
|
||||
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
|
||||
/// can live on the same location (typically for cuda devices).
|
||||
@ -21,7 +21,7 @@ pub trait NdArray {
|
||||
fn to_cpu_storage(&self) -> CpuStorage;
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType> NdArray for S {
|
||||
impl<S: WithDType> NdArray for S {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(()))
|
||||
}
|
||||
@ -31,7 +31,7 @@ impl<S: crate::WithDType> NdArray for S {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
|
||||
impl<S: WithDType, const N: usize> NdArray for &[S; N] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
@ -41,7 +41,7 @@ impl<S: crate::WithDType, const N: usize> NdArray for &[S; N] {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType> NdArray for &[S] {
|
||||
impl<S: WithDType> NdArray for &[S] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
@ -51,7 +51,7 @@ impl<S: crate::WithDType> NdArray for &[S] {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
|
||||
impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from((M, N)))
|
||||
}
|
||||
@ -61,7 +61,7 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: crate::WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
|
||||
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
|
||||
for &[[[S; N3]; N2]; N1]
|
||||
{
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
@ -138,4 +138,15 @@ impl Device {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a `Storage` for this device from an owned vector of elements.
///
/// `data` is taken by value. For `Device::Cpu` it is handed to
/// `S::to_cpu_storage_owned` and the result is wrapped in `Storage::Cpu`.
/// For `Device::Cuda` the same CPU storage is built first and then passed by
/// reference to `cuda_from_cpu_storage`; that result is wrapped in
/// `Storage::Cuda`.
///
/// NOTE(review): presumably `to_cpu_storage_owned` moves the vector into the
/// storage without copying its contents — confirm against the `WithDType` impls.
///
/// # Errors
/// Propagates any error returned by `cuda_from_cpu_storage`.
pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
|
||||
Device::Cuda(device) => {
|
||||
// Stage the data as CPU storage first; the CUDA storage is created from it.
let storage = S::to_cpu_storage_owned(data);
|
||||
let storage = device.cuda_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -195,20 +195,19 @@ impl Tensor {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::F32 => {
|
||||
// TODO: Avoid the data being copied around multiple times.
|
||||
let mut data_t = vec![0f32; elem_count];
|
||||
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
|
||||
Tensor::from_slice(&data_t, shape, &Device::Cpu)
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data_t = vec![0f64; elem_count];
|
||||
reader.read_f64_into::<LittleEndian>(&mut data_t)?;
|
||||
Tensor::from_slice(&data_t, shape, &Device::Cpu)
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::U32 => {
|
||||
let mut data_t = vec![0u32; elem_count];
|
||||
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
|
||||
Tensor::from_slice(&data_t, shape, &Device::Cpu)
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -199,6 +199,37 @@ impl Tensor {
|
||||
Self::new_impl(array, shape, device, true)
|
||||
}
|
||||
|
||||
/// Shared constructor behind `from_vec` and `var_from_vec`: creates a tensor
/// from an owned vector of elements.
///
/// * `data` - the elements, taken by value and forwarded to `Device::storage_owned`.
/// * `shape` - anything convertible into a `Shape`; its element count must equal `data.len()`.
/// * `device` - the device on which the storage is created.
/// * `is_variable` - forwarded unchanged to `from_storage`.
///
/// # Errors
/// Returns `Error::ShapeMismatch` when `data.len()` differs from
/// `shape.elem_count()`, and propagates any error from `storage_owned`.
pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let buffer_size = data.len();
|
||||
// Reject the input before any storage is created if the lengths disagree.
if buffer_size != shape.elem_count() {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape });
|
||||
}
|
||||
let storage = device.storage_owned(data)?;
|
||||
// NOTE(review): the `None` argument presumably means "no originating op" — confirm against `from_storage`.
Ok(from_storage(storage, shape, None, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a tensor from an owned vector of elements with the given shape on
/// the given device.
///
/// Delegates to `from_vec_impl` with `is_variable` set to `false`.
///
/// # Errors
/// Returns whatever `from_vec_impl` returns, including `Error::ShapeMismatch`
/// when `data.len()` does not match the element count of `shape`.
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
Self::from_vec_impl(data, shape, device, false)
|
||||
}
|
||||
|
||||
/// Variable counterpart of `from_vec`: creates a tensor from an owned vector
/// of elements with the given shape on the given device.
///
/// Delegates to `from_vec_impl` with `is_variable` set to `true`.
///
/// # Errors
/// Returns whatever `from_vec_impl` returns, including `Error::ShapeMismatch`
/// when `data.len()` does not match the element count of `shape`.
pub fn var_from_vec<S: Into<Shape>, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
Self::from_vec_impl(data, shape, device, true)
|
||||
}
|
||||
|
||||
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
||||
array: &[D],
|
||||
shape: S,
|
||||
|
Reference in New Issue
Block a user