mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Small cleanups to the llama multi-process example. (#2098)
This commit is contained in:
@ -219,10 +219,14 @@ impl Error {
|
||||
Self::Wrapped(Box::new(err)).bt()
|
||||
}
|
||||
|
||||
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
||||
pub fn msg(err: impl std::error::Error) -> Self {
|
||||
Self::Msg(err.to_string()).bt()
|
||||
}
|
||||
|
||||
pub fn debug(err: impl std::fmt::Debug) -> Self {
|
||||
Self::Msg(format!("{err:?}")).bt()
|
||||
}
|
||||
|
||||
pub fn bt(self) -> Self {
|
||||
let backtrace = std::backtrace::Backtrace::capture();
|
||||
match backtrace.status() {
|
||||
|
@ -76,7 +76,7 @@ struct Args {
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
#[arg(long, default_value = "v3-8b")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long, default_value = "nccl_id.txt")]
|
||||
@ -219,6 +219,9 @@ fn main() -> Result<()> {
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
if Some(next_token) == config.eos_token_id {
|
||||
break;
|
||||
}
|
||||
if rank == 0 {
|
||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
@ -226,6 +229,7 @@ fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
println!();
|
||||
if rank == 0 {
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
|
@ -1,15 +1,14 @@
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
|
||||
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
|
||||
use candle_nn::{Embedding, Linear, Module, RmsNorm};
|
||||
use cudarc::nccl::safe::{Comm, ReduceOp};
|
||||
use half::{bf16, f16};
|
||||
use serde::Deserialize;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::MAX_SEQ_LEN;
|
||||
|
||||
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
|
||||
pub type Config = candle_transformers::models::llama::LlamaConfig;
|
||||
|
||||
struct TensorParallelColumnLinear {
|
||||
linear: Linear,
|
||||
@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {
|
||||
|
||||
struct TensorParallelRowLinear {
|
||||
linear: Linear,
|
||||
comm: Rc<Comm>,
|
||||
all_reduce: AllReduce,
|
||||
}
|
||||
|
||||
struct AllReduce {
|
||||
@ -36,8 +35,6 @@ struct AllReduce {
|
||||
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
|
||||
/// But for this example purposes, this will work
|
||||
unsafe impl Sync for AllReduce {}
|
||||
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
|
||||
/// But for this example purposes, this will work
|
||||
unsafe impl Send for AllReduce {}
|
||||
|
||||
impl CustomOp1 for AllReduce {
|
||||
@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
todo!("implement allreduce for cpu is not necessary for single node");
|
||||
candle::bail!("AllReduce is never used on cpu")
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
@ -56,47 +53,49 @@ impl CustomOp1 for AllReduce {
|
||||
l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::WrapErr;
|
||||
use cudarc::driver::DeviceSlice;
|
||||
use half::{bf16, f16};
|
||||
|
||||
let elem_count = l.shape().elem_count();
|
||||
let dev = s.device().clone();
|
||||
match s.dtype() {
|
||||
let dst = match s.dtype() {
|
||||
DType::BF16 => {
|
||||
let s = s.as_cuda_slice::<bf16>()?;
|
||||
// let s = match l.contiguous_offsets() {
|
||||
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
// Some((o1, o2)) => s.slice(o1..o2),
|
||||
// };
|
||||
let s = match l.contiguous_offsets() {
|
||||
Some((0, l)) if l == s.len() => s,
|
||||
Some(_) | None => candle::bail!("input has to be contiguous"),
|
||||
};
|
||||
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, l.shape().clone()))
|
||||
self.comm
|
||||
.all_reduce(s, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(candle::Error::debug)?;
|
||||
candle::CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
DType::F16 => {
|
||||
let s = s.as_cuda_slice::<f16>()?;
|
||||
// let s = match l.contiguous_offsets() {
|
||||
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
// Some((o1, o2)) => s.slice(o1..o2),
|
||||
// };
|
||||
let s = match l.contiguous_offsets() {
|
||||
Some((0, l)) if l == s.len() => s,
|
||||
Some(_) | None => candle::bail!("input has to be contiguous"),
|
||||
};
|
||||
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, l.shape().clone()))
|
||||
self.comm
|
||||
.all_reduce(s, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(candle::Error::debug)?;
|
||||
candle::CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
dtype => candle::bail!("unsupported dtype {dtype:?}"),
|
||||
}
|
||||
};
|
||||
Ok((dst, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
|
||||
x.apply_op1(AllReduce { comm: comm.clone() })
|
||||
}
|
||||
|
||||
impl TensorParallelRowLinear {
|
||||
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
|
||||
Self { linear, comm }
|
||||
let all_reduce = AllReduce { comm };
|
||||
Self { linear, all_reduce }
|
||||
}
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.linear.forward(x)?;
|
||||
all_reduce_sum(&x, &self.comm)
|
||||
self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
|
||||
}
|
||||
}
|
||||
|
||||
@ -137,23 +136,6 @@ impl TensorParallelRowLinear {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
10_000.0
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Cache {
|
||||
#[allow(clippy::type_complexity)]
|
||||
@ -281,9 +263,7 @@ impl CausalSelfAttention {
|
||||
let v = v.transpose(1, 2)?;
|
||||
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
|
||||
.transpose(1, 2)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
|
||||
.reshape((b_sz, seq_len, hidden_size))?;
|
||||
let y = self.o_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
@ -304,7 +284,7 @@ impl CausalSelfAttention {
|
||||
qkv_proj,
|
||||
o_proj,
|
||||
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
|
||||
num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
|
||||
num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
|
||||
head_dim: cfg.hidden_size / cfg.num_attention_heads,
|
||||
cache: cache.clone(),
|
||||
})
|
||||
@ -318,18 +298,6 @@ struct Mlp {
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(
|
||||
c_fc1: TensorParallelColumnLinear,
|
||||
c_fc2: TensorParallelColumnLinear,
|
||||
c_proj: TensorParallelRowLinear,
|
||||
) -> Self {
|
||||
Self {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
c_proj,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
self.c_proj.forward(&x)
|
||||
@ -339,7 +307,11 @@ impl Mlp {
|
||||
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
|
||||
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
|
||||
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
|
||||
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
||||
Ok(Self {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
c_proj,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -430,10 +402,8 @@ impl Llama {
|
||||
cfg,
|
||||
comm.clone(),
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self::new(wte, blocks, norm, lm_head))
|
||||
}
|
||||
}
|
||||
|
@ -20,6 +20,12 @@ pub struct LlamaConfig {
|
||||
pub eos_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
impl LlamaConfig {
|
||||
pub fn num_key_value_heads(&self) -> usize {
|
||||
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
|
||||
}
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
10_000.0
|
||||
}
|
||||
@ -32,7 +38,7 @@ impl LlamaConfig {
|
||||
vocab_size: self.vocab_size,
|
||||
num_hidden_layers: self.num_hidden_layers,
|
||||
num_attention_heads: self.num_attention_heads,
|
||||
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
|
||||
num_key_value_heads: self.num_key_value_heads(),
|
||||
rms_norm_eps: self.rms_norm_eps,
|
||||
rope_theta: self.rope_theta,
|
||||
use_flash_attn,
|
||||
|
Reference in New Issue
Block a user