From fa760759e5fa94c8486566af6dd3a456d0548221 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 11 Jul 2023 20:22:34 +0100
Subject: [PATCH] Allow for lazy loading of npz files, use it in llama to reduce memory usage in the cpu version. (#141)

---
 candle-core/src/lib.rs                 |  2 +-
 candle-core/src/npy.rs                 | 49 +++++++++++++++++++++++++-
 candle-examples/examples/llama/main.rs |  6 +---
 candle-nn/src/var_builder.rs           | 29 +++++++++++++--
 4 files changed, 77 insertions(+), 9 deletions(-)

diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 06fc87d1..a51c8e29 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -48,7 +48,7 @@ mod indexer;
 mod layout;
 #[cfg(feature = "mkl")]
 mod mkl;
-mod npy;
+pub mod npy;
 mod op;
 pub mod safetensors;
 mod shape;
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 7cf6d381..6302cf71 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -251,7 +251,7 @@ impl Tensor {
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut result = vec![];
         for i in 0..zip.len() {
-            let mut reader = zip.by_index(i).unwrap();
+            let mut reader = zip.by_index(i)?;
             let name = {
                 let name = reader.name();
                 name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
             };
@@ -368,6 +368,53 @@ impl Tensor {
     }
 }
 
+/// Lazy tensor loader.
+pub struct NpzTensors {
+    index_per_name: HashMap<String, usize>,
+    path: std::path::PathBuf,
+    // We do not store a zip reader as it needs mutable access to extract data. Instead we
+    // re-create a zip reader each time.
+}
+
+impl NpzTensors {
+    pub fn new<T: AsRef<std::path::Path>>(path: T) -> Result<Self> {
+        let path = path.as_ref().to_owned();
+        let zip_reader = BufReader::new(File::open(&path)?);
+        let mut zip = zip::ZipArchive::new(zip_reader)?;
+        let mut index_per_name = HashMap::new();
+        for i in 0..zip.len() {
+            let file = zip.by_index(i)?;
+            let name = {
+                let name = file.name();
+                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
+            };
+            index_per_name.insert(name, i);
+        }
+        Ok(Self {
+            index_per_name,
+            path,
+        })
+    }
+
+    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
+        let index = match self.index_per_name.get(name) {
+            None => return Ok(None),
+            Some(index) => *index,
+        };
+        // We hope that the file has not changed since first reading it.
+        let zip_reader = BufReader::new(File::open(&self.path)?);
+        let mut zip = zip::ZipArchive::new(zip_reader)?;
+        let mut reader = zip.by_index(index)?;
+        let header = read_header(&mut reader)?;
+        let header = Header::parse(&header)?;
+        if header.fortran_order {
+            return Err(Error::Npy("fortran order not supported".to_string()));
+        }
+        let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
+        Ok(Some(tensor))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::Header;
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 6ac4458e..aeee6867 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -145,11 +145,7 @@ fn main() -> Result<()> {
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
     let (llama, tokenizer_filename) = match args.npy {
         Some(filename) => {
-            let tensors = Tensor::read_npz(filename)?
-                .into_iter()
-                .map(|(n, t)| Ok((n, t.to_dtype(DTYPE)?)))
-                .collect::<Result<HashMap<String, Tensor>>>()?;
-            let vb = VarBuilder::from_tensors(tensors, DTYPE, &device);
+            let vb = VarBuilder::from_npz(filename, DTYPE, &device)?;
             let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
             (Llama::load(vb, &cache, &config)?, tokenizer)
         }
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 6d79bddd..7f68ae08 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,4 +1,4 @@
-use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor};
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -9,6 +9,7 @@ enum Tensors<'a> {
         routing: HashMap<String, usize>,
         safetensors: Vec<SafeTensors<'a>>,
     },
+    Npz(candle::npy::NpzTensors),
     TensorMap(HashMap<String, Tensor>),
     Zeros,
 }
@@ -53,6 +54,15 @@ impl<'a> TensorData<'a> {
             dtype,
         }
     }
+
+    fn from_npz<P: AsRef<std::path::Path>>(file: P, dtype: DType, device: &Device) -> Result<Self> {
+        let npz = candle::npy::NpzTensors::new(file)?;
+        Ok(Self {
+            tensors: Tensors::Npz(npz),
+            device: device.clone(),
+            dtype,
+        })
+    }
 }
 
 #[derive(Clone)]
@@ -88,6 +98,18 @@ impl<'a> VarBuilder<'a> {
         }
     }
 
+    pub fn from_npz<P: AsRef<std::path::Path>>(
+        file: P,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Self> {
+        let data = TensorData::from_npz(file, dtype, device)?;
+        Ok(Self {
+            data: Arc::new(data),
+            path: vec![],
+        })
+    }
+
     pub fn push_prefix(&self, s: &str) -> Self {
         let mut path = self.path.clone();
         path.push(s.to_string());
@@ -112,7 +134,7 @@ impl<'a> VarBuilder<'a> {
 }
 
 impl<'a> VarBuilder<'a> {
-    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
+    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
         let path = if self.path.is_empty() {
@@ -128,6 +150,9 @@ impl<'a> VarBuilder<'a> {
                     path: path.to_string(),
                 })?
                 .clone(),
+            Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| Error::CannotFindTensor {
+                path: path.to_string(),
+            })?,
             Tensors::SafeTensorWithRouting {
                 routing,
                 safetensors,
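
Usage sketch for the API introduced above, not part of the patch itself. It assumes the crates are imported as candle and candle_nn, as in this repository; the archive name "model.npz", the tensor name "wte.weight" and the (4, 8) shape are placeholders.

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;

    fn run() -> Result<()> {
        let device = Device::Cpu;

        // Low-level handle: only the archive index is read up front, each
        // `get` re-opens the file and decodes a single entry.
        let npz = candle::npy::NpzTensors::new("model.npz")?;
        if let Some(t) = npz.get("wte.weight")? {
            println!("wte.weight: {:?}", t.shape());
        }

        // Higher-level path, as used by the llama example: a VarBuilder
        // backed by the npz file fetches tensors by name on demand.
        let vb = VarBuilder::from_npz("model.npz", DType::F32, &device)?;
        let _t = vb.get((4, 8), "wte.weight")?;
        Ok(())
    }

Re-creating the zip reader on every `get` keeps `NpzTensors` usable behind a shared reference and avoids materializing the whole archive in memory, which is what lowers the peak memory of the cpu llama run.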