diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 8ae828bd..bbafca0b 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -17,8 +17,11 @@ candle-nn = { path = "../candle-nn", version = "0.3.2" }
 candle-transformers = { path = "../candle-transformers", version = "0.3.2" }
 candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.2", optional = true }
 candle-onnx = { path = "../candle-onnx", version = "0.3.2", optional = true }
+
+csv = "1.3.0"
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
+hf-hub = { workspace = true, features=["tokio"]}
 image = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
@@ -28,13 +31,11 @@ safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 tokenizers = { workspace = true, features = ["onig"] }
-csv = "1.3.0"
 
 [dev-dependencies]
 anyhow = { workspace = true }
 byteorder = { workspace = true }
 clap = { workspace = true }
-hf-hub = { workspace = true, features=["tokio"]}
 imageproc = { workspace = true }
 memmap2 = { workspace = true }
 rand = { workspace = true }
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 2b31142e..5ed5e5cb 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -235,10 +235,7 @@ fn main() -> Result<()> {
             if args.quantized {
                 vec![repo.get("model-q4k.gguf")?]
             } else {
-                vec![
-                    repo.get("model-00001-of-00002.safetensors")?,
-                    repo.get("model-00002-of-00002.safetensors")?,
-                ]
+                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
             }
         }
     };
diff --git a/candle-examples/examples/mixtral/main.rs b/candle-examples/examples/mixtral/main.rs
index fcde03c1..1b1a4b36 100644
--- a/candle-examples/examples/mixtral/main.rs
+++ b/candle-examples/examples/mixtral/main.rs
@@ -209,29 +209,7 @@ fn main() -> Result<()> {
             .split(',')
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
-        None => {
-            vec![
-                repo.get("model-00001-of-00019.safetensors")?,
-                repo.get("model-00002-of-00019.safetensors")?,
-                repo.get("model-00003-of-00019.safetensors")?,
-                repo.get("model-00004-of-00019.safetensors")?,
-                repo.get("model-00005-of-00019.safetensors")?,
-                repo.get("model-00006-of-00019.safetensors")?,
-                repo.get("model-00007-of-00019.safetensors")?,
-                repo.get("model-00008-of-00019.safetensors")?,
-                repo.get("model-00009-of-00019.safetensors")?,
-                repo.get("model-00010-of-00019.safetensors")?,
-                repo.get("model-00011-of-00019.safetensors")?,
-                repo.get("model-00012-of-00019.safetensors")?,
-                repo.get("model-00013-of-00019.safetensors")?,
-                repo.get("model-00014-of-00019.safetensors")?,
-                repo.get("model-00015-of-00019.safetensors")?,
-                repo.get("model-00016-of-00019.safetensors")?,
-                repo.get("model-00017-of-00019.safetensors")?,
-                repo.get("model-00018-of-00019.safetensors")?,
-                repo.get("model-00019-of-00019.safetensors")?,
-            ]
-        }
+        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 3574b1f2..c529867b 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -278,10 +278,10 @@ fn main() -> Result<()> {
             } else {
                 match args.model {
                     WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
-                    WhichModel::V2 => vec![
-                        repo.get("model-00001-of-00002.safetensors")?,
-                        repo.get("model-00002-of-00002.safetensors")?,
-                    ],
+                    WhichModel::V2 => candle_examples::hub_load_safetensors(
+                        &repo,
+                        "model.safetensors.index.json",
+                    )?,
                     WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
                     WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
                 }
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 6a446615..8ef108b6 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -96,25 +96,9 @@ impl T5ModelBuilder {
         let api = api.repo(repo);
         let config_filename = api.get("config.json")?;
         let tokenizer_filename = api.get("tokenizer.json")?;
-        let weights_filename = if model_id == "google/flan-t5-xxl" {
-            vec![
-                api.get("model-00001-of-00005.safetensors")?,
-                api.get("model-00002-of-00005.safetensors")?,
-                api.get("model-00003-of-00005.safetensors")?,
-                api.get("model-00004-of-00005.safetensors")?,
-                api.get("model-00005-of-00005.safetensors")?,
-            ]
-        } else if model_id == "google/flan-ul2" {
-            vec![
-                api.get("model-00001-of-00008.safetensors")?,
-                api.get("model-00002-of-00008.safetensors")?,
-                api.get("model-00003-of-00008.safetensors")?,
-                api.get("model-00004-of-00008.safetensors")?,
-                api.get("model-00005-of-00008.safetensors")?,
-                api.get("model-00006-of-00008.safetensors")?,
-                api.get("model-00007-of-00008.safetensors")?,
-                api.get("model-00008-of-00008.safetensors")?,
-            ]
+        let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
+        {
+            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
         } else {
             vec![api.get("model.safetensors")?]
         };
diff --git a/candle-examples/examples/yi/main.rs b/candle-examples/examples/yi/main.rs
index a7184db9..e4cbfc6f 100644
--- a/candle-examples/examples/yi/main.rs
+++ b/candle-examples/examples/yi/main.rs
@@ -218,21 +218,7 @@ fn main() -> Result<()> {
             .split(',')
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
-        None => match args.which {
-            Which::L6b => vec![
-                repo.get("model-00001-of-00002.safetensors")?,
-                repo.get("model-00002-of-00002.safetensors")?,
-            ],
-            Which::L34b => vec![
-                repo.get("model-00001-of-00007.safetensors")?,
-                repo.get("model-00002-of-00007.safetensors")?,
-                repo.get("model-00003-of-00007.safetensors")?,
-                repo.get("model-00004-of-00007.safetensors")?,
-                repo.get("model-00005-of-00007.safetensors")?,
-                repo.get("model-00006-of-00007.safetensors")?,
-                repo.get("model-00007-of-00007.safetensors")?,
-            ],
-        },
+        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index dff31b85..d6dce4a3 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -117,3 +117,30 @@ pub fn save_image_resize<P: AsRef<std::path::Path>>(
     image.save(p).map_err(candle::Error::wrap)?;
     Ok(())
 }
+
+/// Loads the safetensors files for a model from the hub based on a json index file.
+pub fn hub_load_safetensors(
+    repo: &hf_hub::api::sync::ApiRepo,
+    json_file: &str,
+) -> Result<Vec<std::path::PathBuf>> {
+    let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
+    let json_file = std::fs::File::open(json_file)?;
+    let json: serde_json::Value =
+        serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
+    let weight_map = match json.get("weight_map") {
+        None => candle::bail!("no weight map in {json_file:?}"),
+        Some(serde_json::Value::Object(map)) => map,
+        Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
+    };
+    let mut safetensors_files = std::collections::HashSet::new();
+    for value in weight_map.values() {
+        if let Some(file) = value.as_str() {
+            safetensors_files.insert(file.to_string());
+        }
+    }
+    let safetensors_files = safetensors_files
+        .iter()
+        .map(|v| repo.get(v).map_err(candle::Error::wrap))
+        .collect::<Result<Vec<_>>>()?;
+    Ok(safetensors_files)
+}
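
For reference, a minimal sketch of how an example binary can use the new helper end to end. The function name, the model id, the dtype, the device, and the `VarBuilder` wiring below are illustrative assumptions, not part of this patch; any hub repo that ships a `model.safetensors.index.json` would work the same way.

// Hypothetical usage sketch: download the shards listed in a model's
// safetensors index and memory-map them into a VarBuilder.
use candle::{DType, Device};
use candle_nn::VarBuilder;

fn load_sharded_model() -> anyhow::Result<()> {
    let api = hf_hub::api::sync::Api::new()?;
    // Example model id only; pick any sharded safetensors repo.
    let repo = api.model("mistralai/Mistral-7B-v0.1".to_string());
    // Resolves the index file, collects the unique shard names from its
    // weight_map, and downloads each shard, returning the local paths.
    let filenames =
        candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
    // Safety: the shard files are memory-mapped and must not change while mapped.
    let _vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &Device::Cpu)? };
    Ok(())
}

Because the helper reads the shard list from the index file, the per-model hard-coded `model-0000X-of-0000Y.safetensors` lists removed above no longer need to be kept in sync with the hub repos.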