Small cleanups to the llama multi-process example. (#2098)

This commit is contained in:
Laurent Mazare
2024-04-20 22:19:46 +02:00
committed by GitHub
parent dd78422701
commit 587ee3bb6f
4 changed files with 54 additions and 70 deletions

View File

@ -1,15 +1,14 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
use candle_nn::{Embedding, Linear, Module, RmsNorm};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::{bf16, f16};
use serde::Deserialize;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
pub type Config = candle_transformers::models::llama::LlamaConfig;
struct TensorParallelColumnLinear {
linear: Linear,
@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {
struct TensorParallelRowLinear {
linear: Linear,
comm: Rc<Comm>,
all_reduce: AllReduce,
}
struct AllReduce {
@ -36,8 +35,6 @@ struct AllReduce {
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for this example purposes, this will work
unsafe impl Sync for AllReduce {}
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for this example purposes, this will work
unsafe impl Send for AllReduce {}
impl CustomOp1 for AllReduce {
@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
todo!("implement allreduce for cpu is not necessary for single node");
candle::bail!("AllReduce is never used on cpu")
}
#[cfg(feature = "cuda")]
@ -56,47 +53,49 @@ impl CustomOp1 for AllReduce {
l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::WrapErr;
use cudarc::driver::DeviceSlice;
use half::{bf16, f16};
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
match s.dtype() {
let dst = match s.dtype() {
DType::BF16 => {
let s = s.as_cuda_slice::<bf16>()?;
// let s = match l.contiguous_offsets() {
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
// Some((o1, o2)) => s.slice(o1..o2),
// };
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle::bail!("input has to be contiguous"),
};
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, l.shape().clone()))
self.comm
.all_reduce(s, &mut dst, &ReduceOp::Sum)
.map_err(candle::Error::debug)?;
candle::CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F16 => {
let s = s.as_cuda_slice::<f16>()?;
// let s = match l.contiguous_offsets() {
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
// Some((o1, o2)) => s.slice(o1..o2),
// };
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle::bail!("input has to be contiguous"),
};
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, l.shape().clone()))
self.comm
.all_reduce(s, &mut dst, &ReduceOp::Sum)
.map_err(candle::Error::debug)?;
candle::CudaStorage::wrap_cuda_slice(dst, dev)
}
dtype => candle::bail!("unsupported dtype {dtype:?}"),
}
};
Ok((dst, l.shape().clone()))
}
}
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
x.apply_op1(AllReduce { comm: comm.clone() })
}
impl TensorParallelRowLinear {
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
Self { linear, comm }
let all_reduce = AllReduce { comm };
Self { linear, all_reduce }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.linear.forward(x)?;
all_reduce_sum(&x, &self.comm)
self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
}
}
@ -137,23 +136,6 @@ impl TensorParallelRowLinear {
}
}
#[derive(Deserialize)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
}
fn default_rope() -> f32 {
10_000.0
}
#[derive(Clone)]
pub struct Cache {
#[allow(clippy::type_complexity)]
@ -281,9 +263,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
.transpose(1, 2)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
.reshape((b_sz, seq_len, hidden_size))?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
@ -304,7 +284,7 @@ impl CausalSelfAttention {
qkv_proj,
o_proj,
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
head_dim: cfg.hidden_size / cfg.num_attention_heads,
cache: cache.clone(),
})
@ -318,18 +298,6 @@ struct Mlp {
}
impl Mlp {
fn new(
c_fc1: TensorParallelColumnLinear,
c_fc2: TensorParallelColumnLinear,
c_proj: TensorParallelRowLinear,
) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
@ -339,7 +307,11 @@ impl Mlp {
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
Ok(Self {
c_fc1,
c_fc2,
c_proj,
})
}
}
@ -430,10 +402,8 @@ impl Llama {
cfg,
comm.clone(),
)
.unwrap()
})
.collect();
.collect::<Result<Vec<_>>>()?;
Ok(Self::new(wte, blocks, norm, lm_head))
}
}