Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Only download the weights in the main process (and not in the child processes). (#2093)
This commit is contained in:
@ -104,6 +104,24 @@ fn main() -> Result<()> {
|
||||
bail!("comm file {comm_file:?} already exists, please remove it first")
|
||||
}
|
||||
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model) => model,
|
||||
None => match args.which {
|
||||
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
|
||||
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
|
||||
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
|
||||
},
|
||||
};
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
||||
|
||||
let rank = match args.rank {
|
||||
None => {
|
||||
println!("creating {} child processes", args.num_shards);
|
||||
@ -124,24 +142,6 @@ fn main() -> Result<()> {
|
||||
Some(rank) => rank,
|
||||
};
|
||||
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model) => model,
|
||||
None => match args.which {
|
||||
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
|
||||
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
|
||||
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
|
||||
},
|
||||
};
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
||||
|
||||
let num_shards = args.num_shards;
|
||||
// Primitive IPC
|
||||
let id = if rank == 0 {
|
||||
|
Reference in New Issue
Block a user