Multiprocess/multi-GPU support for llama 3. (#2092)

* Multiprocess/multi-GPU support for llama 3.

* Modernize the mp example a bit.
Laurent Mazare
2024-04-20 12:49:21 +02:00
committed by GitHub
parent b45c710dbf
commit c97d639fa0
2 changed files with 123 additions and 135 deletions


@@ -2,7 +2,7 @@ use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
 use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
-use half::f16;
+use half::{bf16, f16};
 use serde::Deserialize;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};
@@ -58,15 +58,31 @@ impl CustomOp1 for AllReduce {
         use candle::cuda_backend::WrapErr;
         let elem_count = l.shape().elem_count();
         let dev = s.device().clone();
-        let s = s.as_cuda_slice::<f16>()?;
-        // let s = match l.contiguous_offsets() {
-        //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-        //     Some((o1, o2)) => s.slice(o1..o2),
-        // };
-        let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-        self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-        Ok((dst, l.shape().clone()))
+        match s.dtype() {
+            DType::BF16 => {
+                let s = s.as_cuda_slice::<bf16>()?;
+                // let s = match l.contiguous_offsets() {
+                //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+                //     Some((o1, o2)) => s.slice(o1..o2),
+                // };
+                let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
+                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
+                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+                Ok((dst, l.shape().clone()))
+            }
+            DType::F16 => {
+                let s = s.as_cuda_slice::<f16>()?;
+                // let s = match l.contiguous_offsets() {
+                //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+                //     Some((o1, o2)) => s.slice(o1..o2),
+                // };
+                let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
+                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+                Ok((dst, l.shape().clone()))
+            }
+            dtype => candle::bail!("unsupported dtype {dtype:?}"),
+        }
     }
 }
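
For context: the all-reduce op now dispatches on the element type, so the example can run with the bf16 weights that Llama 3 checkpoints typically ship with as well as f16. A rough sketch of how a custom op like this is usually wired into a tensor-parallel layer, assuming the `AllReduce` struct defined earlier in this file holds the `comm` handle (as the `self.comm` calls above suggest) and the imports at the top of the file; the wrapper name below is illustrative, not necessarily what the example uses:

    // Hypothetical row-parallel wrapper: every rank holds a shard of the weight
    // matrix, computes a partial matmul, then sums the partial outputs across GPUs
    // with the NCCL all-reduce defined above.
    struct TensorParallelRowLinear {
        linear: Linear,
        all_reduce: AllReduce,
    }

    impl TensorParallelRowLinear {
        fn new(linear: Linear, comm: Rc<Comm>) -> Self {
            Self { linear, all_reduce: AllReduce { comm } }
        }

        fn forward(&self, x: &Tensor) -> Result<Tensor> {
            // apply_op1_no_bwd runs the CustomOp1 without building a backward graph,
            // which is all this inference-only example needs.
            self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
        }
    }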
@@ -161,7 +177,6 @@ impl Cache {
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
-        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
         let cos = idx_theta.cos()?.to_dtype(dtype)?;
         let sin = idx_theta.sin()?.to_dtype(dtype)?;
         Ok(Self {
@@ -197,16 +212,10 @@ struct CausalSelfAttention {
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
+        let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
-        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
-        let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
-        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
-        Ok(rope)
+        candle_nn::rotary_emb::rope(x, &cos, &sin)
     }
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
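
The hand-rolled rotation is replaced by the fused `candle_nn::rotary_emb::rope` kernel. This also explains the Cache change above, where the cos/sin tables stop being duplicated with `Tensor::cat` and stay at half the head dimension, and the `.contiguous()` calls added in the next hunk, which feed the kernel the dense layout it expects. As a rough reference for what the call is assumed to compute, mirroring the removed lines (an illustrative re-derivation, not the kernel's actual implementation):

    use candle::{Result, Tensor, D};

    /// Illustrative reference: rotate the pairs (x[..., i], x[..., i + d/2]) by the
    /// angle tables, with x of shape (b, n_heads, seq_len, head_dim) and cos/sin of
    /// shape (seq_len, head_dim / 2).
    fn rope_reference(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
        let (_b, _h, _t, d) = x.shape().dims4()?;
        // Widen the half-size tables back to the full head_dim, as the removed code kept them.
        let cos = Tensor::cat(&[cos, cos], D::Minus1)?;
        let sin = Tensor::cat(&[sin, sin], D::Minus1)?;
        let x1 = x.narrow(D::Minus1, 0, d / 2)?;
        let x2 = x.narrow(D::Minus1, d / 2, d / 2)?;
        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
        // cos/sin are (t, d) here; broadcasting aligns them against (b, h, t, d).
        x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?
    }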
@@ -232,13 +241,16 @@ impl CausalSelfAttention {
         let q = q
             .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let k = k
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let mut v = v
             .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
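
A minimal standalone illustration of why the `.contiguous()` calls are added: `transpose` only swaps strides and returns a view, while fused kernels like the rope above generally need a densely laid-out buffer. The shapes below are arbitrary:

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let x = Tensor::zeros((2, 5, 8, 16), DType::F32, &Device::Cpu)?;
        // Swapping dims 1 and 2 produces a strided view, not a copy.
        let t = x.transpose(1, 2)?;
        assert!(!t.is_contiguous());
        // `contiguous` materializes the (2, 8, 5, 16) layout so a kernel can index it linearly.
        let t = t.contiguous()?;
        assert!(t.is_contiguous());
        Ok(())
    }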
@@ -278,16 +290,7 @@ impl CausalSelfAttention {
     fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
         let n_rep = self.num_attention_heads / self.num_key_value_heads;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
-            Ok(x)
-        }
+        candle_transformers::utils::repeat_kv(x, n_rep)
     }
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
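
The inline grouped-query-attention helper is replaced by the shared `candle_transformers::utils::repeat_kv`. A rough equivalent of what that utility does, under the assumption that it simply replicates each key/value head `n_rep` times so K and V line up with the query heads (the shared helper may implement this differently internally):

    use candle::{Result, Tensor};

    /// Illustrative equivalent of repeat_kv for grouped-query attention: expand the
    /// (b, n_kv_head, seq_len, head_dim) K/V tensors to
    /// (b, n_kv_head * n_rep, seq_len, head_dim), repeating each kv head n_rep times.
    fn repeat_kv_sketch(xs: Tensor, n_rep: usize) -> Result<Tensor> {
        if n_rep == 1 {
            return Ok(xs);
        }
        let (b_sz, n_kv_head, seq_len, head_dim) = xs.shape().dims4()?;
        // Concatenate n_rep copies along dim 2, then fold the copies into the head axis.
        Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }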