From 8b390ddd290cfdcff8ef319d266b9d466838494f Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 20 Apr 2024 13:01:23 +0200
Subject: [PATCH] Only download the weights in the main process (and not in the
 child processes). (#2093)

---
 .../examples/llama_multiprocess/main.rs       | 36 +++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index 2b914cee..3b03b873 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -104,6 +104,24 @@ fn main() -> Result<()> {
         bail!("comm file {comm_file:?} already exists, please remove it first")
     }
 
+    let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model) => model,
+        None => match args.which {
+            Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
+            Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
+            Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
+        },
+    };
+    println!("loading the model weights from {model_id}");
+    let revision = args.revision.unwrap_or("main".to_string());
+    let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+    let config_filename = api.get("config.json")?;
+    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+    let tokenizer_filename = api.get("tokenizer.json")?;
+    let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
+
     let rank = match args.rank {
         None => {
             println!("creating {} child processes", args.num_shards);
@@ -124,24 +142,6 @@ fn main() -> Result<()> {
         Some(rank) => rank,
     };
 
-    let api = Api::new()?;
-    let model_id = match args.model_id {
-        Some(model) => model,
-        None => match args.which {
-            Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
-            Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
-            Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
-            Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
-        },
-    };
-    println!("loading the model weights from {model_id}");
-    let revision = args.revision.unwrap_or("main".to_string());
-    let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
-    let config_filename = api.get("config.json")?;
-    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
-    let tokenizer_filename = api.get("tokenizer.json")?;
-    let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
-
     let num_shards = args.num_shards;
     // Primitive IPC
     let id = if rank == 0 {