Multiprocess/multi-GPU support for llama 3. (#2092)

* Multiprocess/multi-GPU support for llama 3.

* Modernize the mp example a bit.
This commit is contained in:
Laurent Mazare
2024-04-20 12:49:21 +02:00
committed by GitHub
parent b45c710dbf
commit c97d639fa0
2 changed files with 123 additions and 135 deletions

View File

@ -10,7 +10,7 @@
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use clap::Parser;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
@ -24,57 +24,15 @@ mod model;
use model::{Config, Llama};
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!
EDWARD:
Dazzle mine eyes, or do I see three suns?
RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.
EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
V2_7b,
V2_70b,
V3_8b,
V3_70b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@ -86,8 +44,8 @@ struct Args {
rank: Option<usize>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
@ -117,6 +75,12 @@ struct Args {
#[arg(long)]
dtype: Option<String>,
#[arg(long)]
which: Which,
#[arg(long, default_value = "nccl_id.txt")]
comm_file: String,
}
fn main() -> Result<()> {
@ -129,14 +93,47 @@ fn main() -> Result<()> {
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
None => match args.which {
Which::V2_7b | Which::V2_70b => DType::F16,
Which::V3_8b | Which::V3_70b => DType::BF16,
},
};
let comm_file = std::path::PathBuf::from(&args.comm_file);
if comm_file.exists() {
bail!("comm file {comm_file:?} already exists, please remove it first")
}
let rank = match args.rank {
None => {
println!("creating {} child processes", args.num_shards);
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait()?;
}
return Ok(());
}
Some(rank) => rank,
};
let api = Api::new()?;
let model_id = args
.model_id
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
let model_id = match args.model_id {
Some(model) => model,
None => match args.which {
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
},
};
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@ -145,39 +142,20 @@ fn main() -> Result<()> {
let tokenizer_filename = api.get("tokenizer.json")?;
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
if args.rank.is_none() {
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait().unwrap();
}
return Ok(());
}
let i = args.rank.unwrap();
let num_shards = args.num_shards;
let rank = i;
// Primitive IPC
let id = if rank == 0 {
let id = Id::new().unwrap();
std::fs::File::create("nccl_id.txt.tmp")?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
.unwrap();
std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
let tmp_file = comm_file.with_extension(".comm.tgz");
std::fs::File::create(&tmp_file)?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
std::fs::rename(&tmp_file, &comm_file)?;
id
} else {
let path = std::path::PathBuf::from("nccl_id.txt");
while !path.exists() {
while !comm_file.exists() {
std::thread::sleep(std::time::Duration::from_secs(1));
}
let data = std::fs::read("nccl_id.txt")?;
let data = std::fs::read(&comm_file)?;
let internal: [i8; 128] = data
.into_iter()
.map(|i| i as i8)
@ -187,14 +165,17 @@ fn main() -> Result<()> {
let id: Id = Id::uninit(internal);
id
};
let device = CudaDevice::new(i)?;
let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
let device = CudaDevice::new(rank)?;
let comm = match Comm::from_rank(device, rank, num_shards, id) {
Ok(comm) => Rc::new(comm),
Err(err) => anyhow::bail!("nccl error {:?}", err.0),
};
if rank == 0 {
std::fs::remove_file("nccl_id.txt")?;
std::fs::remove_file(comm_file)?;
}
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
let device = Device::new_cuda(rank)?;
let cache = model::Cache::new(dtype, &config, &device)?;
println!("building the model");
@ -210,14 +191,24 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
// Only start timing at the second token as processing the first token waits for all the
// weights to be loaded in an async way.
if index == 1 {
start_gen = std::time::Instant::now()
};
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
@ -229,24 +220,18 @@ fn main() -> Result<()> {
tokens.push(next_token);
new_tokens.push(next_token);
if rank == 0 {
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
}
let dt = start_gen.elapsed();
if rank == 0 {
let dt = start_gen.elapsed();
println!(
"{} tokens generated ({} token/s)\n----\n{}\n----",
"\n\n{} tokens generated ({} token/s)\n",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer
.decode(new_tokens.as_slice(), true)
.map_err(E::msg)?
(args.sample_len - 1) as f64 / dt.as_secs_f64(),
);
}
Ok(())

View File

@ -2,7 +2,7 @@ use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{Embedding, Linear, Module, RmsNorm};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use half::{bf16, f16};
use serde::Deserialize;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
@ -58,15 +58,31 @@ impl CustomOp1 for AllReduce {
use candle::cuda_backend::WrapErr;
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
let s = s.as_cuda_slice::<f16>()?;
// let s = match l.contiguous_offsets() {
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
// Some((o1, o2)) => s.slice(o1..o2),
// };
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, l.shape().clone()))
match s.dtype() {
DType::BF16 => {
let s = s.as_cuda_slice::<bf16>()?;
// let s = match l.contiguous_offsets() {
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
// Some((o1, o2)) => s.slice(o1..o2),
// };
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, l.shape().clone()))
}
DType::F16 => {
let s = s.as_cuda_slice::<f16>()?;
// let s = match l.contiguous_offsets() {
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
// Some((o1, o2)) => s.slice(o1..o2),
// };
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, l.shape().clone()))
}
dtype => candle::bail!("unsupported dtype {dtype:?}"),
}
}
}
@ -161,7 +177,6 @@ impl Cache {
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
@ -197,16 +212,10 @@ struct CausalSelfAttention {
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
candle_nn::rotary_emb::rope(x, &cos, &sin)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
@ -232,13 +241,16 @@ impl CausalSelfAttention {
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
@ -278,16 +290,7 @@ impl CausalSelfAttention {
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.num_attention_heads / self.num_key_value_heads;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
Ok(x)
}
candle_transformers::utils::repeat_kv(x, n_rep)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {