mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Helper function to load sharded safetensors files (#1481)
* Fix the quantized mistral example. * Add a helper function to load sharded safetensors weights. * Use the sharded loader.
This commit is contained in:
@ -235,10 +235,7 @@ fn main() -> Result<()> {
|
||||
if args.quantized {
|
||||
vec![repo.get("model-q4k.gguf")?]
|
||||
} else {
|
||||
vec![
|
||||
repo.get("model-00001-of-00002.safetensors")?,
|
||||
repo.get("model-00002-of-00002.safetensors")?,
|
||||
]
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -209,29 +209,7 @@ fn main() -> Result<()> {
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
vec![
|
||||
repo.get("model-00001-of-00019.safetensors")?,
|
||||
repo.get("model-00002-of-00019.safetensors")?,
|
||||
repo.get("model-00003-of-00019.safetensors")?,
|
||||
repo.get("model-00004-of-00019.safetensors")?,
|
||||
repo.get("model-00005-of-00019.safetensors")?,
|
||||
repo.get("model-00006-of-00019.safetensors")?,
|
||||
repo.get("model-00007-of-00019.safetensors")?,
|
||||
repo.get("model-00008-of-00019.safetensors")?,
|
||||
repo.get("model-00009-of-00019.safetensors")?,
|
||||
repo.get("model-00010-of-00019.safetensors")?,
|
||||
repo.get("model-00011-of-00019.safetensors")?,
|
||||
repo.get("model-00012-of-00019.safetensors")?,
|
||||
repo.get("model-00013-of-00019.safetensors")?,
|
||||
repo.get("model-00014-of-00019.safetensors")?,
|
||||
repo.get("model-00015-of-00019.safetensors")?,
|
||||
repo.get("model-00016-of-00019.safetensors")?,
|
||||
repo.get("model-00017-of-00019.safetensors")?,
|
||||
repo.get("model-00018-of-00019.safetensors")?,
|
||||
repo.get("model-00019-of-00019.safetensors")?,
|
||||
]
|
||||
}
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
@ -278,10 +278,10 @@ fn main() -> Result<()> {
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::V2 => vec![
|
||||
repo.get("model-00001-of-00002.safetensors")?,
|
||||
repo.get("model-00002-of-00002.safetensors")?,
|
||||
],
|
||||
WhichModel::V2 => candle_examples::hub_load_safetensors(
|
||||
&repo,
|
||||
"model.safetensors.index.json",
|
||||
)?,
|
||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
|
||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
|
||||
}
|
||||
|
@ -96,25 +96,9 @@ impl T5ModelBuilder {
|
||||
let api = api.repo(repo);
|
||||
let config_filename = api.get("config.json")?;
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let weights_filename = if model_id == "google/flan-t5-xxl" {
|
||||
vec![
|
||||
api.get("model-00001-of-00005.safetensors")?,
|
||||
api.get("model-00002-of-00005.safetensors")?,
|
||||
api.get("model-00003-of-00005.safetensors")?,
|
||||
api.get("model-00004-of-00005.safetensors")?,
|
||||
api.get("model-00005-of-00005.safetensors")?,
|
||||
]
|
||||
} else if model_id == "google/flan-ul2" {
|
||||
vec![
|
||||
api.get("model-00001-of-00008.safetensors")?,
|
||||
api.get("model-00002-of-00008.safetensors")?,
|
||||
api.get("model-00003-of-00008.safetensors")?,
|
||||
api.get("model-00004-of-00008.safetensors")?,
|
||||
api.get("model-00005-of-00008.safetensors")?,
|
||||
api.get("model-00006-of-00008.safetensors")?,
|
||||
api.get("model-00007-of-00008.safetensors")?,
|
||||
api.get("model-00008-of-00008.safetensors")?,
|
||||
]
|
||||
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
|
||||
{
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
} else {
|
||||
vec![api.get("model.safetensors")?]
|
||||
};
|
||||
|
@ -218,21 +218,7 @@ fn main() -> Result<()> {
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.which {
|
||||
Which::L6b => vec![
|
||||
repo.get("model-00001-of-00002.safetensors")?,
|
||||
repo.get("model-00002-of-00002.safetensors")?,
|
||||
],
|
||||
Which::L34b => vec![
|
||||
repo.get("model-00001-of-00007.safetensors")?,
|
||||
repo.get("model-00002-of-00007.safetensors")?,
|
||||
repo.get("model-00003-of-00007.safetensors")?,
|
||||
repo.get("model-00004-of-00007.safetensors")?,
|
||||
repo.get("model-00005-of-00007.safetensors")?,
|
||||
repo.get("model-00006-of-00007.safetensors")?,
|
||||
repo.get("model-00007-of-00007.safetensors")?,
|
||||
],
|
||||
},
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
Reference in New Issue
Block a user