Allow for lazy loading of npz files; use it in llama to reduce memory usage in the CPU version. (#141)
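In a nutshell: the llama example previously materialized every tensor of the `.npz` file up front via `Tensor::read_npz`; after this change it builds its `VarBuilder` over a lazy index, and tensors are decoded only when requested. A minimal sketch of the new entry point, assuming the `candle_nn` import path; the file name, shape, tensor name, dtype, and device below are placeholders, not values from this diff:

use candle::{DType, Device, Result};
use candle_nn::VarBuilder; // import path assumed

fn main() -> Result<()> {
    // Before: Tensor::read_npz(..) decoded the whole archive into a HashMap.
    // Now: only the zip directory is scanned; no tensor data is read yet.
    let vb = VarBuilder::from_npz("llama.npz", DType::F32, &Device::Cpu)?;
    // Each lookup decodes a single entry from the archive on demand.
    let _emb = vb.get((32000, 4096), "embedding.weight")?;
    Ok(())
}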
@@ -48,7 +48,7 @@ mod indexer;
 mod layout;
 #[cfg(feature = "mkl")]
 mod mkl;
-mod npy;
+pub mod npy;
 mod op;
 pub mod safetensors;
 mod shape;
@@ -251,7 +251,7 @@ impl Tensor {
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut result = vec![];
         for i in 0..zip.len() {
-            let mut reader = zip.by_index(i).unwrap();
+            let mut reader = zip.by_index(i)?;
             let name = {
                 let name = reader.name();
                 name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
@@ -368,6 +368,53 @@ impl Tensor {
     }
 }
+
+/// Lazy tensor loader.
+pub struct NpzTensors {
+    index_per_name: HashMap<String, usize>,
+    path: std::path::PathBuf,
+    // We do not store a zip reader as it needs mutable access to extract data. Instead we
+    // re-create a zip reader each time.
+}
+
+impl NpzTensors {
+    pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
+        let path = path.as_ref().to_owned();
+        let zip_reader = BufReader::new(File::open(&path)?);
+        let mut zip = zip::ZipArchive::new(zip_reader)?;
+        let mut index_per_name = HashMap::new();
+        for i in 0..zip.len() {
+            let file = zip.by_index(i)?;
+            let name = {
+                let name = file.name();
+                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
+            };
+            index_per_name.insert(name, i);
+        }
+        Ok(Self {
+            index_per_name,
+            path,
+        })
+    }
+
+    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
+        let index = match self.index_per_name.get(name) {
+            None => return Ok(None),
+            Some(index) => *index,
+        };
+        // We hope that the file has not changed since first reading it.
+        let zip_reader = BufReader::new(File::open(&self.path)?);
+        let mut zip = zip::ZipArchive::new(zip_reader)?;
+        let mut reader = zip.by_index(index)?;
+        let header = read_header(&mut reader)?;
+        let header = Header::parse(&header)?;
+        if header.fortran_order {
+            return Err(Error::Npy("fortran order not supported".to_string()));
+        }
+        let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
+        Ok(Some(tensor))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::Header;
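For reference, a short sketch of how the new `NpzTensors` type can be driven on its own (file and tensor names are placeholders). `new` only walks the zip directory to build the name-to-index map; each `get` re-opens the archive and decodes a single `.npy` entry, which is what keeps peak memory low:

use candle::npy::NpzTensors;
use candle::Result;

fn peek_tensor(path: &str, name: &str) -> Result<()> {
    // Builds the name -> zip-entry index; no tensor payload is decoded here.
    let npz = NpzTensors::new(path)?;
    // Re-opens the file and decodes just this entry; Ok(None) if absent.
    match npz.get(name)? {
        Some(t) => println!("{name}: shape {:?}", t.shape()),
        None => println!("{name}: not found in archive"),
    }
    Ok(())
}

Re-creating the zip reader on every call is the trade-off the struct comment points at: `get` stays usable through `&self` and nothing is cached, at the cost of re-opening the file on each lookup.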
@@ -145,11 +145,7 @@ fn main() -> Result<()> {
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
     let (llama, tokenizer_filename) = match args.npy {
         Some(filename) => {
-            let tensors = Tensor::read_npz(filename)?
-                .into_iter()
-                .map(|(n, t)| Ok((n, t.to_dtype(DTYPE)?)))
-                .collect::<Result<std::collections::HashMap<String, Tensor>>>()?;
-            let vb = VarBuilder::from_tensors(tensors, DTYPE, &device);
+            let vb = VarBuilder::from_npz(filename, DTYPE, &device)?;
             let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
             (Llama::load(vb, &cache, &config)?, tokenizer)
         }
@@ -1,4 +1,4 @@
-use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor};
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -9,6 +9,7 @@ enum Tensors<'a> {
         routing: HashMap<String, usize>,
         safetensors: Vec<SafeTensors<'a>>,
     },
+    Npz(candle::npy::NpzTensors),
     TensorMap(HashMap<String, Tensor>),
     Zeros,
 }
@@ -53,6 +54,15 @@ impl<'a> TensorData<'a> {
             dtype,
         }
     }
+
+    fn from_npz<P: AsRef<std::path::Path>>(file: P, dtype: DType, device: &Device) -> Result<Self> {
+        let npz = candle::npy::NpzTensors::new(file)?;
+        Ok(Self {
+            tensors: Tensors::Npz(npz),
+            device: device.clone(),
+            dtype,
+        })
+    }
 }
 
 #[derive(Clone)]
@@ -88,6 +98,18 @@ impl<'a> VarBuilder<'a> {
         }
     }
+
+    pub fn from_npz<P: AsRef<std::path::Path>>(
+        file: P,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let data = TensorData::from_npz(file, dtype, device)?;
+        Ok(Self {
+            data: Arc::new(data),
+            path: vec![],
+        })
+    }
 
     pub fn push_prefix(&self, s: &str) -> Self {
         let mut path = self.path.clone();
         path.push(s.to_string());
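Together with the `Tensors::Npz` match arm in the final hunk below, this gives the lazy end-to-end path: the builder holds only the npz index, and each `get` decodes one entry on demand. A sketch of typical use with `push_prefix`, using a placeholder shape and tensor name (the `candle_nn` import path is an assumption):

use candle::{Result, Tensor};
use candle_nn::VarBuilder; // import path assumed

// Placeholder layer loader: `get` resolves the prefixed tensor name against
// the npz index and decodes only that entry from the archive.
fn load_attn_proj(vb: &VarBuilder) -> Result<Tensor> {
    let vb = vb.push_prefix("attn");
    vb.get((4096, 4096), "proj.weight")
}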
@@ -112,7 +134,7 @@ impl<'a> VarBuilder<'a> {
 }
 
 impl<'a> VarBuilder<'a> {
-    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
+    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
         let path = if self.path.is_empty() {
@@ -128,6 +150,9 @@ impl<'a> VarBuilder<'a> {
                     path: path.to_string(),
                 })?
                 .clone(),
+            Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| Error::CannotFindTensor {
+                path: path.to_string(),
+            })?,
             Tensors::SafeTensorWithRouting {
                 routing,
                 safetensors,