Helper function to load sharded safetensors files (#1481)

* Fix the quantized mistral example.

* Add a helper function to load sharded safetensors weights.

* Use the sharded loader.
This commit is contained in:
Laurent Mazare
2023-12-25 21:49:21 +01:00
committed by GitHub
parent eae3a20d43
commit 37c539f2b7
7 changed files with 40 additions and 67 deletions

View File

@ -96,25 +96,9 @@ impl T5ModelBuilder {
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" {
vec![
api.get("model-00001-of-00005.safetensors")?,
api.get("model-00002-of-00005.safetensors")?,
api.get("model-00003-of-00005.safetensors")?,
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
{
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
} else {
vec![api.get("model.safetensors")?]
};