mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Allow for lazy loading of npz files, use it in llama to reduce memory usage in the cpu version. (#141)
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -9,6 +9,7 @@ enum Tensors<'a> {
|
||||
routing: HashMap<String, usize>,
|
||||
safetensors: Vec<SafeTensors<'a>>,
|
||||
},
|
||||
Npz(candle::npy::NpzTensors),
|
||||
TensorMap(HashMap<String, Tensor>),
|
||||
Zeros,
|
||||
}
|
||||
@ -53,6 +54,15 @@ impl<'a> TensorData<'a> {
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_npz<P: AsRef<std::path::Path>>(file: P, dtype: DType, device: &Device) -> Result<Self> {
|
||||
let npz = candle::npy::NpzTensors::new(file)?;
|
||||
Ok(Self {
|
||||
tensors: Tensors::Npz(npz),
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -88,6 +98,18 @@ impl<'a> VarBuilder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_npz<P: AsRef<std::path::Path>>(
|
||||
file: P,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let data = TensorData::from_npz(file, dtype, device)?;
|
||||
Ok(Self {
|
||||
data: Arc::new(data),
|
||||
path: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
pub fn push_prefix(&self, s: &str) -> Self {
|
||||
let mut path = self.path.clone();
|
||||
path.push(s.to_string());
|
||||
@ -112,7 +134,7 @@ impl<'a> VarBuilder<'a> {
|
||||
}
|
||||
|
||||
impl<'a> VarBuilder<'a> {
|
||||
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
|
||||
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
|
||||
let data = self.data.as_ref();
|
||||
let s: Shape = s.into();
|
||||
let path = if self.path.is_empty() {
|
||||
@ -128,6 +150,9 @@ impl<'a> VarBuilder<'a> {
|
||||
path: path.to_string(),
|
||||
})?
|
||||
.clone(),
|
||||
Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| Error::CannotFindTensor {
|
||||
path: path.to_string(),
|
||||
})?,
|
||||
Tensors::SafeTensorWithRouting {
|
||||
routing,
|
||||
safetensors,
|
||||
|
Reference in New Issue
Block a user