Helper function to load sharded safetensors files (#1481)

* Fix the quantized mistral example.

* Add a helper function to load sharded safetensors weights.

* Use the sharded loader.
Author: Laurent Mazare
Date: 2023-12-25 21:49:21 +01:00 (committed by GitHub)
Commit: 37c539f2b7 (parent eae3a20d43)
7 changed files with 40 additions and 67 deletions

@@ -218,21 +218,7 @@ fn main() -> Result<()> {
             .split(',')
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
-        None => match args.which {
-            Which::L6b => vec![
-                repo.get("model-00001-of-00002.safetensors")?,
-                repo.get("model-00002-of-00002.safetensors")?,
-            ],
-            Which::L34b => vec![
-                repo.get("model-00001-of-00007.safetensors")?,
-                repo.get("model-00002-of-00007.safetensors")?,
-                repo.get("model-00003-of-00007.safetensors")?,
-                repo.get("model-00004-of-00007.safetensors")?,
-                repo.get("model-00005-of-00007.safetensors")?,
-                repo.get("model-00006-of-00007.safetensors")?,
-                repo.get("model-00007-of-00007.safetensors")?,
-            ],
-        },
+        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
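
The hunk above replaces the hard-coded per-model shard lists with a single call to the new candle_examples::hub_load_safetensors helper, which derives the shard filenames from the model.safetensors.index.json file that sharded Hugging Face checkpoints ship with. Below is a minimal sketch of how such a helper can be written; it assumes the hf-hub sync API (ApiRepo::get), serde_json, and anyhow for error handling, and is not the exact implementation that landed in candle-examples.

    use std::collections::HashSet;
    use std::path::PathBuf;

    use anyhow::{Context, Result};
    use hf_hub::api::sync::ApiRepo;

    /// Download every shard referenced by a safetensors index file
    /// (e.g. "model.safetensors.index.json") and return the local paths.
    /// Hypothetical sketch; the real helper in candle-examples may differ.
    fn hub_load_safetensors(repo: &ApiRepo, index_file: &str) -> Result<Vec<PathBuf>> {
        // The index file maps tensor names to the shard file that contains them.
        let index_path = repo.get(index_file)?;
        let index: serde_json::Value =
            serde_json::from_reader(std::fs::File::open(index_path)?)?;
        let weight_map = index
            .get("weight_map")
            .and_then(|v| v.as_object())
            .context("no `weight_map` object in the index file")?;
        // Many tensors share a shard, so deduplicate the filenames first.
        let shards: HashSet<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
        // Fetch each shard (hf-hub caches downloads) and collect the local paths.
        shards.into_iter().map(|shard| Ok(repo.get(shard)?)).collect()
    }

With a helper like this, each example only passes the index filename instead of maintaining one shard list per model size, which is exactly the simplification the diff above makes.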