mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use yoke to provide a self-referential container for mmaped safetenso… (#939)
* Use yoke to provide a self-referential container for mmaped safetensor files. * Add the new self-owned type for safetensor files without removing the previous version. * Add routing. * Add an initializer for the case of multiple files.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use candle::{safetensors::Load, DType, Device, Result, Shape, Tensor, Var};
|
||||
use candle::{DType, Device, Result, Shape, Tensor, Var};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
@ -40,18 +40,12 @@ impl VarMap {
|
||||
/// Note that values for variables that are currently not in the map are not kept.
|
||||
pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
|
||||
let path = path.as_ref();
|
||||
let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
|
||||
let data = data.deserialize()?;
|
||||
let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? };
|
||||
let mut tensor_data = self.data.lock().unwrap();
|
||||
for (name, var) in tensor_data.iter_mut() {
|
||||
match data.tensor(name) {
|
||||
Ok(data) => {
|
||||
let data: Tensor = data.load(var.device())?;
|
||||
if let Err(err) = var.set(&data) {
|
||||
candle::bail!("error setting {name} using data from {path:?}: {err}",)
|
||||
}
|
||||
}
|
||||
Err(_) => candle::bail!("cannot find tensor for {name}"),
|
||||
let data = data.load(name, var.device())?;
|
||||
if let Err(err) = var.set(&data) {
|
||||
candle::bail!("error setting {name} using data from {path:?}: {err}",)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user