From c97d639fa0a51af3b89415e149db279ed7e686f9 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 20 Apr 2024 12:49:21 +0200
Subject: [PATCH] Multiprocess/multi-GPU support for llama 3. (#2092)

* Multiprocess/multi-GPU support for llama 3.

* Modernize the mp example a bit.
---
 .../examples/llama_multiprocess/main.rs  | 191 ++++++++----------
 .../examples/llama_multiprocess/model.rs |  67 +++---
 2 files changed, 123 insertions(+), 135 deletions(-)

diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index bc158817..2b914cee 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -10,7 +10,7 @@ extern crate intel_mkl_src;
 
 use anyhow::{bail, Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle::{DType, Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
@@ -24,57 +24,15 @@ mod model;
 use model::{Config, Llama};
 
 const MAX_SEQ_LEN: usize = 4096;
-const DEFAULT_PROMPT: &str = r"
-EDWARD:
-I wonder how our princely father 'scaped,
-Or whether he be 'scaped away or no
-From Clifford's and Northumberland's pursuit:
-Had he been ta'en, we should have heard the news;
-Had he been slain, we should have heard the news;
-Or had he 'scaped, methinks we should have heard
-The happy tidings of his good escape.
-How fares my brother? why is he so sad?
+const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
-RICHARD:
-I cannot joy, until I be resolved
-Where our right valiant father is become.
-I saw him in the battle range about;
-And watch'd him how he singled Clifford forth.
-Methought he bore him in the thickest troop
-As doth a lion in a herd of neat;
-Or as a bear, encompass'd round with dogs,
-Who having pinch'd a few and made them cry,
-The rest stand all aloof, and bark at him.
-So fared our father with his enemies;
-So fled his enemies my warlike father:
-Methinks, 'tis prize enough to be his son.
-See how the morning opes her golden gates,
-And takes her farewell of the glorious sun!
-How well resembles it the prime of youth,
-Trimm'd like a younker prancing to his love!
-
-EDWARD:
-Dazzle mine eyes, or do I see three suns?
-
-RICHARD:
-Three glorious suns, each one a perfect sun;
-Not separated with the racking clouds,
-But sever'd in a pale clear-shining sky.
-See, see! they join, embrace, and seem to kiss,
-As if they vow'd some league inviolable:
-Now are they but one lamp, one light, one sun.
-In this the heaven figures some event.
-
-EDWARD:
-'Tis wondrous strange, the like yet never heard of.
-I think it cites us, brother, to the field,
-That we, the sons of brave Plantagenet,
-Each one already blazing by our meeds,
-Should notwithstanding join our lights together
-And over-shine the earth as this the world.
-Whate'er it bodes, henceforward will I bear
-Upon my target three fair-shining suns.
-";
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+    V2_7b,
+    V2_70b,
+    V3_8b,
+    V3_70b,
+}
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -86,8 +44,8 @@ struct Args {
     rank: Option<usize>,
 
     /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,
 
     /// Nucleus sampling probability cutoff.
     #[arg(long)]
@@ -117,6 +75,12 @@ struct Args {
 
     #[arg(long)]
     dtype: Option<String>,
+
+    #[arg(long)]
+    which: Which,
+
+    #[arg(long, default_value = "nccl_id.txt")]
+    comm_file: String,
 }
 
 fn main() -> Result<()> {
@@ -129,14 +93,47 @@ fn main() -> Result<()> {
         Some("bf16") => DType::BF16,
         Some("f32") => DType::F32,
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
-        None => DType::F16,
+        None => match args.which {
+            Which::V2_7b | Which::V2_70b => DType::F16,
+            Which::V3_8b | Which::V3_70b => DType::BF16,
+        },
+    };
+
+    let comm_file = std::path::PathBuf::from(&args.comm_file);
+    if comm_file.exists() {
+        bail!("comm file {comm_file:?} already exists, please remove it first")
+    }
+
+    let rank = match args.rank {
+        None => {
+            println!("creating {} child processes", args.num_shards);
+            let children: Vec<_> = (0..args.num_shards)
+                .map(|rank| {
+                    let mut args: std::collections::VecDeque<_> = std::env::args().collect();
+                    args.push_back("--rank".to_string());
+                    args.push_back(format!("{rank}"));
+                    let name = args.pop_front().unwrap();
+                    std::process::Command::new(name).args(args).spawn().unwrap()
+                })
+                .collect();
+            for mut child in children {
+                child.wait()?;
+            }
+            return Ok(());
+        }
+        Some(rank) => rank,
     };
 
     let api = Api::new()?;
-
-    let model_id = args
-        .model_id
-        .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
+    let model_id = match args.model_id {
+        Some(model) => model,
+        None => match args.which {
+            Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
+            Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
+            Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
+        },
+    };
     println!("loading the model weights from {model_id}");
     let revision = args.revision.unwrap_or("main".to_string());
     let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@@ -145,39 +142,20 @@ fn main() -> Result<()> {
     let tokenizer_filename = api.get("tokenizer.json")?;
     let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
 
-    if args.rank.is_none() {
-        let children: Vec<_> = (0..args.num_shards)
-            .map(|rank| {
-                let mut args: std::collections::VecDeque<_> = std::env::args().collect();
-                args.push_back("--rank".to_string());
-                args.push_back(format!("{rank}"));
-                let name = args.pop_front().unwrap();
-                std::process::Command::new(name).args(args).spawn().unwrap()
-            })
-            .collect();
-        for mut child in children {
-            child.wait().unwrap();
-        }
-        return Ok(());
-    }
-
-    let i = args.rank.unwrap();
     let num_shards = args.num_shards;
-    let rank = i;
     // Primitive IPC
     let id = if rank == 0 {
         let id = Id::new().unwrap();
-        std::fs::File::create("nccl_id.txt.tmp")?
-            .write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
-            .unwrap();
-        std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
+        let tmp_file = comm_file.with_extension(".comm.tgz");
+        std::fs::File::create(&tmp_file)?
+            .write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
+        std::fs::rename(&tmp_file, &comm_file)?;
         id
     } else {
-        let path = std::path::PathBuf::from("nccl_id.txt");
-        while !path.exists() {
+        while !comm_file.exists() {
             std::thread::sleep(std::time::Duration::from_secs(1));
         }
-        let data = std::fs::read("nccl_id.txt")?;
+        let data = std::fs::read(&comm_file)?;
         let internal: [i8; 128] = data
             .into_iter()
             .map(|i| i as i8)
@@ -187,14 +165,17 @@ fn main() -> Result<()> {
         let id: Id = Id::uninit(internal);
         id
     };
-    let device = CudaDevice::new(i)?;
-    let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
+    let device = CudaDevice::new(rank)?;
+    let comm = match Comm::from_rank(device, rank, num_shards, id) {
+        Ok(comm) => Rc::new(comm),
+        Err(err) => anyhow::bail!("nccl error {:?}", err.0),
+    };
     if rank == 0 {
-        std::fs::remove_file("nccl_id.txt")?;
+        std::fs::remove_file(comm_file)?;
     }
     println!("Rank {rank:?} spawned");
 
-    let device = Device::new_cuda(i)?;
+    let device = Device::new_cuda(rank)?;
     let cache = model::Cache::new(dtype, &config, &device)?;
 
     println!("building the model");
@@ -210,14 +191,24 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
 
     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
+    let temperature = if args.temperature <= 0. {
+        None
+    } else {
+        Some(args.temperature)
+    };
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
     let mut new_tokens = vec![];
-    let start_gen = std::time::Instant::now();
+    let mut start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     for index in 0..args.sample_len {
-        let start_gen = std::time::Instant::now();
+        // Only start timing at the second token as processing the first token waits for all the
+        // weights to be loaded in an async way.
+        if index == 1 {
+            start_gen = std::time::Instant::now()
+        };
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
@@ -229,24 +220,18 @@ fn main() -> Result<()> {
         tokens.push(next_token);
         new_tokens.push(next_token);
         if rank == 0 {
-            println!("> {:?}", start_gen.elapsed());
-            println!(
-                "{} token: {} '{}'",
-                index + 1,
-                next_token,
-                tokenizer.decode(&[next_token], true).map_err(E::msg)?
-            );
+            if let Some(t) = tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
         }
     }
-    let dt = start_gen.elapsed();
     if rank == 0 {
+        let dt = start_gen.elapsed();
        println!(
-            "{} tokens generated ({} token/s)\n----\n{}\n----",
+            "\n\n{} tokens generated ({} token/s)\n",
             args.sample_len,
-            args.sample_len as f64 / dt.as_secs_f64(),
-            tokenizer
-                .decode(new_tokens.as_slice(), true)
-                .map_err(E::msg)?
+            (args.sample_len - 1) as f64 / dt.as_secs_f64(),
         );
     }
     Ok(())
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index bb5c2368..414b1242 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -2,7 +2,7 @@ use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
 use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
-use half::f16;
+use half::{bf16, f16};
 use serde::Deserialize;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};
@@ -58,15 +58,31 @@ impl CustomOp1 for AllReduce {
         use candle::cuda_backend::WrapErr;
         let elem_count = l.shape().elem_count();
         let dev = s.device().clone();
-        let s = s.as_cuda_slice::<f16>()?;
-        // let s = match l.contiguous_offsets() {
-        //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-        //     Some((o1, o2)) => s.slice(o1..o2),
-        // };
-        let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-        self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-        Ok((dst, l.shape().clone()))
+        match s.dtype() {
+            DType::BF16 => {
+                let s = s.as_cuda_slice::<bf16>()?;
+                // let s = match l.contiguous_offsets() {
+                //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+                //     Some((o1, o2)) => s.slice(o1..o2),
+                // };
+                let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
+                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
+                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+                Ok((dst, l.shape().clone()))
+            }
+            DType::F16 => {
+                let s = s.as_cuda_slice::<f16>()?;
+                // let s = match l.contiguous_offsets() {
+                //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+                //     Some((o1, o2)) => s.slice(o1..o2),
+                // };
+                let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
+                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+                Ok((dst, l.shape().clone()))
+            }
+            dtype => candle::bail!("unsupported dtype {dtype:?}"),
+        }
     }
 }
 
@@ -161,7 +177,6 @@ impl Cache {
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
-        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
         let cos = idx_theta.cos()?.to_dtype(dtype)?;
         let sin = idx_theta.sin()?.to_dtype(dtype)?;
         Ok(Self {
@@ -197,16 +212,10 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
+        let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
-        let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
-        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
-        Ok(rope)
+        candle_nn::rotary_emb::rope(x, &cos, &sin)
     }
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
@@ -232,13 +241,16 @@ impl CausalSelfAttention {
 
         let q = q
             .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let k = k
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let mut v = v
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
@@ -278,16 +290,7 @@ impl CausalSelfAttention {
 
     fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
         let n_rep = self.num_attention_heads / self.num_key_value_heads;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
-            Ok(x)
-        }
+        candle_transformers::utils::repeat_kv(x, n_rep)
    }
 
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
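
A note on the "Primitive IPC" used above: rank 0 publishes the NCCL unique id by writing it to a temporary file and atomically renaming it to the path given by --comm-file (nccl_id.txt by default), while the other ranks poll until that file appears and then read it back. The sketch below reproduces that rendezvous pattern in isolation using only the standard library; the 128-byte buffer stands in for the NCCL Id, and the helper names publish_id and wait_for_id are illustrative only, not part of the example.

use std::io::Write;
use std::path::Path;

// Rank 0 side: write to a temporary file first so readers never observe a
// partially written id, then rename, which is atomic on the same filesystem.
fn publish_id(comm_file: &Path, id_bytes: &[u8]) -> std::io::Result<()> {
    let tmp_file = comm_file.with_extension("tmp");
    std::fs::File::create(&tmp_file)?.write_all(id_bytes)?;
    std::fs::rename(&tmp_file, comm_file)
}

// Non-zero ranks: poll until the file shows up, mirroring the loop in main.rs.
fn wait_for_id(comm_file: &Path) -> std::io::Result<Vec<u8>> {
    while !comm_file.exists() {
        std::thread::sleep(std::time::Duration::from_secs(1));
    }
    std::fs::read(comm_file)
}

fn main() -> std::io::Result<()> {
    let comm_file = Path::new("nccl_id.txt");
    let id_bytes = [0u8; 128]; // placeholder for the bytes of Id::internal()
    publish_id(comm_file, &id_bytes)?;
    let received = wait_for_id(comm_file)?;
    assert_eq!(received.len(), 128);
    std::fs::remove_file(comm_file)?; // rank 0 cleans up once the comm is built
    Ok(())
}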