Helper function to load sharded safetensors files (#1481)

* Fix the quantized mistral example.

* Add a helper function to load sharded safetensors weights.

* Use the sharded loader.
This commit is contained in:
Laurent Mazare
2023-12-25 21:49:21 +01:00
committed by GitHub
parent eae3a20d43
commit 37c539f2b7
7 changed files with 40 additions and 67 deletions

View File

@ -117,3 +117,30 @@ pub fn save_image_resize<P: AsRef<std::path::Path>>(
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle::Error::wrap))
.collect::<Result<Vec<_>>>()?;
Ok(safetensors_files)
}