mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Allow for lazy loading of npz files, use it in llama to reduce memory usage in the cpu version. (#141)
This commit is contained in:
@ -251,7 +251,7 @@ impl Tensor {
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut result = vec![];
|
||||
for i in 0..zip.len() {
|
||||
let mut reader = zip.by_index(i).unwrap();
|
||||
let mut reader = zip.by_index(i)?;
|
||||
let name = {
|
||||
let name = reader.name();
|
||||
name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
|
||||
@ -368,6 +368,53 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Lazy tensor loader.
|
||||
pub struct NpzTensors {
|
||||
index_per_name: HashMap<String, usize>,
|
||||
path: std::path::PathBuf,
|
||||
// We do not store a zip reader as it needs mutable access to extract data. Instead we
|
||||
// re-create a zip reader each time.
|
||||
}
|
||||
|
||||
impl NpzTensors {
|
||||
pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
|
||||
let path = path.as_ref().to_owned();
|
||||
let zip_reader = BufReader::new(File::open(&path)?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut index_per_name = HashMap::new();
|
||||
for i in 0..zip.len() {
|
||||
let file = zip.by_index(i)?;
|
||||
let name = {
|
||||
let name = file.name();
|
||||
name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
|
||||
};
|
||||
index_per_name.insert(name, i);
|
||||
}
|
||||
Ok(Self {
|
||||
index_per_name,
|
||||
path,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
let index = match self.index_per_name.get(name) {
|
||||
None => return Ok(None),
|
||||
Some(index) => *index,
|
||||
};
|
||||
// We hope that the file has not changed since first reading it.
|
||||
let zip_reader = BufReader::new(File::open(&self.path)?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_index(index)?;
|
||||
let header = read_header(&mut reader)?;
|
||||
let header = Header::parse(&header)?;
|
||||
if header.fortran_order {
|
||||
return Err(Error::Npy("fortran order not supported".to_string()));
|
||||
}
|
||||
let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
|
||||
Ok(Some(tensor))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Header;
|
||||
|
Reference in New Issue
Block a user