Use yoke to provide a self-referential container for mmaped safetenso… (#939)

* Use yoke to provide a self-referential container for mmaped safetensor files.

* Add the new self-owned type for safetensor files without removing the previous version.

* Add routing.

* Add an initializer for the case of multiple files.
This commit is contained in:
Laurent Mazare
2023-09-23 15:43:11 +01:00
committed by GitHub
parent 402d207f0f
commit ccf352f3d1
5 changed files with 104 additions and 15 deletions

View File

@ -1,4 +1,4 @@
use candle::{safetensors::Load, DType, Device, Result, Shape, Tensor, Var};
use candle::{DType, Device, Result, Shape, Tensor, Var};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -40,18 +40,12 @@ impl VarMap {
/// Note that values for variables that are currently not in the map are not kept.
pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
let path = path.as_ref();
let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
let data = data.deserialize()?;
let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? };
let mut tensor_data = self.data.lock().unwrap();
for (name, var) in tensor_data.iter_mut() {
match data.tensor(name) {
Ok(data) => {
let data: Tensor = data.load(var.device())?;
if let Err(err) = var.set(&data) {
candle::bail!("error setting {name} using data from {path:?}: {err}",)
}
}
Err(_) => candle::bail!("cannot find tensor for {name}"),
let data = data.load(name, var.device())?;
if let Err(err) = var.set(&data) {
candle::bail!("error setting {name} using data from {path:?}: {err}",)
}
}
Ok(())