mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Only download the weights in the main process (and not in the child processes). (#2093)
This commit is contained in:
@ -104,6 +104,24 @@ fn main() -> Result<()> {
|
|||||||
bail!("comm file {comm_file:?} already exists, please remove it first")
|
bail!("comm file {comm_file:?} already exists, please remove it first")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model) => model,
|
||||||
|
None => match args.which {
|
||||||
|
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||||
|
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
|
||||||
|
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
|
||||||
|
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
println!("loading the model weights from {model_id}");
|
||||||
|
let revision = args.revision.unwrap_or("main".to_string());
|
||||||
|
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||||
|
let config_filename = api.get("config.json")?;
|
||||||
|
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
|
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||||
|
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
||||||
|
|
||||||
let rank = match args.rank {
|
let rank = match args.rank {
|
||||||
None => {
|
None => {
|
||||||
println!("creating {} child processes", args.num_shards);
|
println!("creating {} child processes", args.num_shards);
|
||||||
@ -124,24 +142,6 @@ fn main() -> Result<()> {
|
|||||||
Some(rank) => rank,
|
Some(rank) => rank,
|
||||||
};
|
};
|
||||||
|
|
||||||
let api = Api::new()?;
|
|
||||||
let model_id = match args.model_id {
|
|
||||||
Some(model) => model,
|
|
||||||
None => match args.which {
|
|
||||||
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
|
|
||||||
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
|
|
||||||
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
|
|
||||||
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
println!("loading the model weights from {model_id}");
|
|
||||||
let revision = args.revision.unwrap_or("main".to_string());
|
|
||||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
|
||||||
let config_filename = api.get("config.json")?;
|
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
|
||||||
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
|
||||||
|
|
||||||
let num_shards = args.num_shards;
|
let num_shards = args.num_shards;
|
||||||
// Primitive IPC
|
// Primitive IPC
|
||||||
let id = if rank == 0 {
|
let id = if rank == 0 {
|
||||||
|
Reference in New Issue
Block a user