mirror of https://github.com/huggingface/candle.git
Small cleanups to the llama multi-process example. (#2098)
@@ -1,15 +1,14 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
+use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
 use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
-use half::{bf16, f16};
-use serde::Deserialize;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};

 use super::MAX_SEQ_LEN;

-use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
+pub type Config = candle_transformers::models::llama::LlamaConfig;

 struct TensorParallelColumnLinear {
     linear: Linear,
@@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {

 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    all_reduce: AllReduce,
 }

 struct AllReduce {
@@ -36,8 +35,6 @@ struct AllReduce {
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
-/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
-/// But for this example purposes, this will work
 unsafe impl Send for AllReduce {}

 impl CustomOp1 for AllReduce {
@@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
     }

     fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
-        todo!("implement allreduce for cpu is not necessary for single node");
+        candle::bail!("AllReduce is never used on cpu")
     }

     #[cfg(feature = "cuda")]
@@ -56,47 +53,49 @@ impl CustomOp1 for AllReduce {
         l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::WrapErr;
+        use cudarc::driver::DeviceSlice;
+        use half::{bf16, f16};
+
         let elem_count = l.shape().elem_count();
         let dev = s.device().clone();
-        match s.dtype() {
+        let dst = match s.dtype() {
             DType::BF16 => {
                 let s = s.as_cuda_slice::<bf16>()?;
-                // let s = match l.contiguous_offsets() {
-                //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-                //     Some((o1, o2)) => s.slice(o1..o2),
-                // };
+                let s = match l.contiguous_offsets() {
+                    Some((0, l)) if l == s.len() => s,
+                    Some(_) | None => candle::bail!("input has to be contiguous"),
+                };
                 let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
-                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-                Ok((dst, l.shape().clone()))
+                self.comm
+                    .all_reduce(s, &mut dst, &ReduceOp::Sum)
+                    .map_err(candle::Error::debug)?;
+                candle::CudaStorage::wrap_cuda_slice(dst, dev)
             }
             DType::F16 => {
                 let s = s.as_cuda_slice::<f16>()?;
-                // let s = match l.contiguous_offsets() {
-                //     None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-                //     Some((o1, o2)) => s.slice(o1..o2),
-                // };
+                let s = match l.contiguous_offsets() {
+                    Some((0, l)) if l == s.len() => s,
+                    Some(_) | None => candle::bail!("input has to be contiguous"),
+                };
                 let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-                Ok((dst, l.shape().clone()))
+                self.comm
+                    .all_reduce(s, &mut dst, &ReduceOp::Sum)
+                    .map_err(candle::Error::debug)?;
+                candle::CudaStorage::wrap_cuda_slice(dst, dev)
             }
             dtype => candle::bail!("unsupported dtype {dtype:?}"),
-        }
+        };
+        Ok((dst, l.shape().clone()))
     }
 }

-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
-    x.apply_op1(AllReduce { comm: comm.clone() })
-}
-
 impl TensorParallelRowLinear {
     fn new(linear: Linear, comm: Rc<Comm>) -> Self {
-        Self { linear, comm }
+        let all_reduce = AllReduce { comm };
+        Self { linear, all_reduce }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = self.linear.forward(x)?;
-        all_reduce_sum(&x, &self.comm)
+        self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
     }
 }

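Aside (not part of the commit): the hunk above drops the all_reduce_sum helper and instead applies a stored AllReduce op directly with apply_op1_no_bwd. A minimal, hypothetical sketch of that CustomOp1 + apply_op1_no_bwd pattern, using a CPU-only scale-by-two op so it runs without CUDA or NCCL; it uses the candle_core crate name (the example aliases it as candle), and ScaleByTwo is made up for illustration.

// Hypothetical sketch of the CustomOp1 + apply_op1_no_bwd pattern, CPU only.
use candle_core::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

struct ScaleByTwo;

impl CustomOp1 for ScaleByTwo {
    fn name(&self) -> &'static str {
        "scale-by-two"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        // Only handle contiguous f32 inputs, mirroring the contiguity check above.
        let (start, end) = match l.contiguous_offsets() {
            Some(offsets) => offsets,
            None => candle_core::bail!("input has to be contiguous"),
        };
        let vs = match s {
            CpuStorage::F32(vs) => &vs[start..end],
            _ => candle_core::bail!("only f32 is supported in this sketch"),
        };
        let out: Vec<f32> = vs.iter().map(|v| v * 2.0).collect();
        Ok((CpuStorage::F32(out), l.shape().clone()))
    }
}

fn main() -> Result<()> {
    let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // No backward pass is registered for the op, just like the AllReduce above.
    let y = x.apply_op1_no_bwd(&ScaleByTwo)?;
    println!("{:?}", y.to_vec1::<f32>()?);
    Ok(())
}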
@@ -137,23 +136,6 @@ impl TensorParallelRowLinear {
     }
 }

-#[derive(Deserialize)]
-pub struct Config {
-    pub hidden_size: usize,
-    pub intermediate_size: usize,
-    pub vocab_size: usize,
-    pub num_hidden_layers: usize,
-    pub num_attention_heads: usize,
-    pub num_key_value_heads: usize,
-    pub rms_norm_eps: f64,
-    #[serde(default = "default_rope")]
-    pub rope_theta: f32,
-}
-
-fn default_rope() -> f32 {
-    10_000.0
-}
-
 #[derive(Clone)]
 pub struct Cache {
     #[allow(clippy::type_complexity)]
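Aside (not part of the commit): the hand-rolled Config struct above is replaced by candle_transformers::models::llama::LlamaConfig, which already matches the Hugging Face config.json layout and, as the later hunk shows, exposes num_key_value_heads() as a method. A small, hypothetical sketch of loading such a config, assuming the serde_json crate and a locally available config.json; the world_size value is a stand-in.

// Hypothetical sketch: deserialize an HF-style config.json into LlamaConfig,
// then derive the per-rank head counts the way the example does.
use candle_transformers::models::llama::LlamaConfig;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Assumed local path; in the real example the file comes from the HF hub.
    let raw = std::fs::read("config.json")?;
    let cfg: LlamaConfig = serde_json::from_slice(&raw)?;

    // The example shards attention heads across ranks by dividing by the NCCL
    // world size; 2 is just an illustrative value here.
    let world_size = 2;
    println!(
        "heads per rank: {}, kv heads per rank: {}",
        cfg.num_attention_heads / world_size,
        cfg.num_key_value_heads() / world_size
    );
    Ok(())
}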
@@ -281,9 +263,7 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
         let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
-            .transpose(1, 2)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
+            .reshape((b_sz, seq_len, hidden_size))?;
         let y = self.o_proj.forward(&y)?;
         Ok(y)
     }
@@ -304,7 +284,7 @@ impl CausalSelfAttention {
             qkv_proj,
             o_proj,
             num_attention_heads: cfg.num_attention_heads / comm.world_size(),
-            num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
+            num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
             head_dim: cfg.hidden_size / cfg.num_attention_heads,
             cache: cache.clone(),
         })
@@ -318,18 +298,6 @@ struct Mlp {
 }

 impl Mlp {
-    fn new(
-        c_fc1: TensorParallelColumnLinear,
-        c_fc2: TensorParallelColumnLinear,
-        c_proj: TensorParallelRowLinear,
-    ) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
@@ -339,7 +307,11 @@ impl Mlp {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
+        Ok(Self {
+            c_fc1,
+            c_fc2,
+            c_proj,
+        })
     }
 }

@@ -430,10 +402,8 @@ impl Llama {
                     cfg,
                     comm.clone(),
                 )
-                .unwrap()
             })
-            .collect();
-
+            .collect::<Result<Vec<_>>>()?;
         Ok(Self::new(wte, blocks, norm, lm_head))
     }
 }