mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Multiprocess/multi-GPU support for llama 3. (#2092)
* Multiprocess/multi-GPU support for llama 3. * Modernize the mp example a bit.
This commit is contained in:
@ -10,7 +10,7 @@
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{bail, Error as E, Result};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
@ -24,57 +24,15 @@ mod model;
|
||||
use model::{Config, Llama};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_PROMPT: &str = r"
|
||||
EDWARD:
|
||||
I wonder how our princely father 'scaped,
|
||||
Or whether he be 'scaped away or no
|
||||
From Clifford's and Northumberland's pursuit:
|
||||
Had he been ta'en, we should have heard the news;
|
||||
Had he been slain, we should have heard the news;
|
||||
Or had he 'scaped, methinks we should have heard
|
||||
The happy tidings of his good escape.
|
||||
How fares my brother? why is he so sad?
|
||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
|
||||
RICHARD:
|
||||
I cannot joy, until I be resolved
|
||||
Where our right valiant father is become.
|
||||
I saw him in the battle range about;
|
||||
And watch'd him how he singled Clifford forth.
|
||||
Methought he bore him in the thickest troop
|
||||
As doth a lion in a herd of neat;
|
||||
Or as a bear, encompass'd round with dogs,
|
||||
Who having pinch'd a few and made them cry,
|
||||
The rest stand all aloof, and bark at him.
|
||||
So fared our father with his enemies;
|
||||
So fled his enemies my warlike father:
|
||||
Methinks, 'tis prize enough to be his son.
|
||||
See how the morning opes her golden gates,
|
||||
And takes her farewell of the glorious sun!
|
||||
How well resembles it the prime of youth,
|
||||
Trimm'd like a younker prancing to his love!
|
||||
|
||||
EDWARD:
|
||||
Dazzle mine eyes, or do I see three suns?
|
||||
|
||||
RICHARD:
|
||||
Three glorious suns, each one a perfect sun;
|
||||
Not separated with the racking clouds,
|
||||
But sever'd in a pale clear-shining sky.
|
||||
See, see! they join, embrace, and seem to kiss,
|
||||
As if they vow'd some league inviolable:
|
||||
Now are they but one lamp, one light, one sun.
|
||||
In this the heaven figures some event.
|
||||
|
||||
EDWARD:
|
||||
'Tis wondrous strange, the like yet never heard of.
|
||||
I think it cites us, brother, to the field,
|
||||
That we, the sons of brave Plantagenet,
|
||||
Each one already blazing by our meeds,
|
||||
Should notwithstanding join our lights together
|
||||
And over-shine the earth as this the world.
|
||||
Whate'er it bodes, henceforward will I bear
|
||||
Upon my target three fair-shining suns.
|
||||
";
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
V2_7b,
|
||||
V2_70b,
|
||||
V3_8b,
|
||||
V3_70b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -86,8 +44,8 @@ struct Args {
|
||||
rank: Option<usize>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
@ -117,6 +75,12 @@ struct Args {
|
||||
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
which: Which,
|
||||
|
||||
#[arg(long, default_value = "nccl_id.txt")]
|
||||
comm_file: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -129,14 +93,47 @@ fn main() -> Result<()> {
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
None => match args.which {
|
||||
Which::V2_7b | Which::V2_70b => DType::F16,
|
||||
Which::V3_8b | Which::V3_70b => DType::BF16,
|
||||
},
|
||||
};
|
||||
|
||||
let comm_file = std::path::PathBuf::from(&args.comm_file);
|
||||
if comm_file.exists() {
|
||||
bail!("comm file {comm_file:?} already exists, please remove it first")
|
||||
}
|
||||
|
||||
let rank = match args.rank {
|
||||
None => {
|
||||
println!("creating {} child processes", args.num_shards);
|
||||
let children: Vec<_> = (0..args.num_shards)
|
||||
.map(|rank| {
|
||||
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
|
||||
args.push_back("--rank".to_string());
|
||||
args.push_back(format!("{rank}"));
|
||||
let name = args.pop_front().unwrap();
|
||||
std::process::Command::new(name).args(args).spawn().unwrap()
|
||||
})
|
||||
.collect();
|
||||
for mut child in children {
|
||||
child.wait()?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
Some(rank) => rank,
|
||||
};
|
||||
|
||||
let api = Api::new()?;
|
||||
|
||||
let model_id = args
|
||||
.model_id
|
||||
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
|
||||
let model_id = match args.model_id {
|
||||
Some(model) => model,
|
||||
None => match args.which {
|
||||
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
|
||||
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
|
||||
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
|
||||
},
|
||||
};
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
@ -145,39 +142,20 @@ fn main() -> Result<()> {
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
||||
|
||||
if args.rank.is_none() {
|
||||
let children: Vec<_> = (0..args.num_shards)
|
||||
.map(|rank| {
|
||||
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
|
||||
args.push_back("--rank".to_string());
|
||||
args.push_back(format!("{rank}"));
|
||||
let name = args.pop_front().unwrap();
|
||||
std::process::Command::new(name).args(args).spawn().unwrap()
|
||||
})
|
||||
.collect();
|
||||
for mut child in children {
|
||||
child.wait().unwrap();
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let i = args.rank.unwrap();
|
||||
let num_shards = args.num_shards;
|
||||
let rank = i;
|
||||
// Primitive IPC
|
||||
let id = if rank == 0 {
|
||||
let id = Id::new().unwrap();
|
||||
std::fs::File::create("nccl_id.txt.tmp")?
|
||||
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
|
||||
.unwrap();
|
||||
std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
|
||||
let tmp_file = comm_file.with_extension(".comm.tgz");
|
||||
std::fs::File::create(&tmp_file)?
|
||||
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
|
||||
std::fs::rename(&tmp_file, &comm_file)?;
|
||||
id
|
||||
} else {
|
||||
let path = std::path::PathBuf::from("nccl_id.txt");
|
||||
while !path.exists() {
|
||||
while !comm_file.exists() {
|
||||
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
}
|
||||
let data = std::fs::read("nccl_id.txt")?;
|
||||
let data = std::fs::read(&comm_file)?;
|
||||
let internal: [i8; 128] = data
|
||||
.into_iter()
|
||||
.map(|i| i as i8)
|
||||
@ -187,14 +165,17 @@ fn main() -> Result<()> {
|
||||
let id: Id = Id::uninit(internal);
|
||||
id
|
||||
};
|
||||
let device = CudaDevice::new(i)?;
|
||||
let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
|
||||
let device = CudaDevice::new(rank)?;
|
||||
let comm = match Comm::from_rank(device, rank, num_shards, id) {
|
||||
Ok(comm) => Rc::new(comm),
|
||||
Err(err) => anyhow::bail!("nccl error {:?}", err.0),
|
||||
};
|
||||
if rank == 0 {
|
||||
std::fs::remove_file("nccl_id.txt")?;
|
||||
std::fs::remove_file(comm_file)?;
|
||||
}
|
||||
println!("Rank {rank:?} spawned");
|
||||
|
||||
let device = Device::new_cuda(i)?;
|
||||
let device = Device::new_cuda(rank)?;
|
||||
let cache = model::Cache::new(dtype, &config, &device)?;
|
||||
|
||||
println!("building the model");
|
||||
@ -210,14 +191,24 @@ fn main() -> Result<()> {
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
let temperature = if args.temperature <= 0. {
|
||||
None
|
||||
} else {
|
||||
Some(args.temperature)
|
||||
};
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
|
||||
let mut new_tokens = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
for index in 0..args.sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
// Only start timing at the second token as processing the first token waits for all the
|
||||
// weights to be loaded in an async way.
|
||||
if index == 1 {
|
||||
start_gen = std::time::Instant::now()
|
||||
};
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
@ -229,24 +220,18 @@ fn main() -> Result<()> {
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
if rank == 0 {
|
||||
println!("> {:?}", start_gen.elapsed());
|
||||
println!(
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
next_token,
|
||||
tokenizer.decode(&[next_token], true).map_err(E::msg)?
|
||||
);
|
||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if rank == 0 {
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"{} tokens generated ({} token/s)\n----\n{}\n----",
|
||||
"\n\n{} tokens generated ({} token/s)\n",
|
||||
args.sample_len,
|
||||
args.sample_len as f64 / dt.as_secs_f64(),
|
||||
tokenizer
|
||||
.decode(new_tokens.as_slice(), true)
|
||||
.map_err(E::msg)?
|
||||
(args.sample_len - 1) as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user