Small cleanups to the llama multi-process example. (#2098)
@@ -219,10 +219,14 @@ impl Error {
         Self::Wrapped(Box::new(err)).bt()
     }
 
-    pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+    pub fn msg(err: impl std::error::Error) -> Self {
         Self::Msg(err.to_string()).bt()
     }
 
+    pub fn debug(err: impl std::fmt::Debug) -> Self {
+        Self::Msg(format!("{err:?}")).bt()
+    }
+
     pub fn bt(self) -> Self {
         let backtrace = std::backtrace::Backtrace::capture();
         match backtrace.status() {
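For context, a minimal sketch of how the new `Error::debug` helper can be used; it is not part of the commit, and `NcclLikeError` is a made-up stand-in for an error type (such as the NCCL errors further down in this diff) that implements `Debug` but not `std::error::Error`. The crate is referenced by its in-tree name `candle`.

```rust
// Hypothetical error type: Debug only, no std::error::Error impl,
// so Error::msg (which needs std::error::Error) does not apply.
#[derive(Debug)]
struct NcclLikeError(i32);

fn all_reduce_stub() -> std::result::Result<(), NcclLikeError> {
    Err(NcclLikeError(-1))
}

fn run() -> candle::Result<()> {
    // Error::debug formats the error with {:?} and attaches a backtrace,
    // matching how the model code below maps NCCL failures.
    all_reduce_stub().map_err(candle::Error::debug)?;
    Ok(())
}

fn main() {
    println!("{:?}", run());
}
```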
@@ -76,7 +76,7 @@ struct Args {
     #[arg(long)]
     dtype: Option<String>,
 
-    #[arg(long)]
+    #[arg(long, default_value = "v3-8b")]
     which: Which,
 
     #[arg(long, default_value = "nccl_id.txt")]
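As a side note on the new default, a small sketch of how clap resolves `default_value` against a `ValueEnum` argument; the `Which` variants here are assumed for illustration and are not the example's real enum.

```rust
use clap::{Parser, ValueEnum};

// Made-up variant set; the real example defines its own `Which` enum.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "v2-7b")]
    V27b,
    #[value(name = "v3-8b")]
    V38b,
}

#[derive(Parser, Debug)]
struct Args {
    /// Model variant to run; "v3-8b" is used when the flag is omitted.
    #[arg(long, default_value = "v3-8b")]
    which: Which,
}

fn main() {
    // `--which v2-7b` overrides the default; no flag yields Which::V38b.
    let args = Args::parse();
    println!("selected: {:?}", args.which);
}
```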
@@ -219,6 +219,9 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
         new_tokens.push(next_token);
+        if Some(next_token) == config.eos_token_id {
+            break;
+        }
         if rank == 0 {
             if let Some(t) = tokenizer.next_token(next_token)? {
                 print!("{t}");
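To show the control flow this adds, here is a self-contained toy version of the sampling loop with the early eos exit; `sample_next` is a stub standing in for the model forward pass and `logits_processor.sample`.

```rust
// Stub sampler: cycles through token ids and eventually hits the eos id.
fn sample_next(tokens: &[u32]) -> u32 {
    (tokens.len() % 5) as u32
}

fn main() {
    let eos_token_id = Some(3u32);
    let mut tokens = vec![1u32];
    let mut new_tokens = Vec::new();
    for _ in 0..100 {
        let next_token = sample_next(&tokens);
        tokens.push(next_token);
        new_tokens.push(next_token);
        // Mirror of the check added above: stop as soon as eos is produced
        // instead of always generating the full requested length.
        if Some(next_token) == eos_token_id {
            break;
        }
    }
    println!("generated {} tokens", new_tokens.len());
}
```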
@@ -226,6 +229,7 @@ fn main() -> Result<()> {
             }
         }
     }
+    println!();
     if rank == 0 {
         let dt = start_gen.elapsed();
         println!(
@@ -1,15 +1,14 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
+use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
 use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
-use half::{bf16, f16};
-use serde::Deserialize;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
 
-use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
+pub type Config = candle_transformers::models::llama::LlamaConfig;
 
 struct TensorParallelColumnLinear {
     linear: Linear,
@@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {
 
 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    all_reduce: AllReduce,
 }
 
 struct AllReduce {
@@ -36,8 +35,6 @@ struct AllReduce {
 /// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
 /// But for this example purposes, this will work
 unsafe impl Sync for AllReduce {}
-/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
-/// But for this example purposes, this will work
 unsafe impl Send for AllReduce {}
 
 impl CustomOp1 for AllReduce {
@@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
     }
 
     fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
-        todo!("implement allreduce for cpu is not necessary for single node");
+        candle::bail!("AllReduce is never used on cpu")
     }
 
     #[cfg(feature = "cuda")]
@@ -56,47 +53,49 @@ impl CustomOp1 for AllReduce {
         l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::cuda_backend::WrapErr;
+        use cudarc::driver::DeviceSlice;
+        use half::{bf16, f16};
 
         let elem_count = l.shape().elem_count();
         let dev = s.device().clone();
-        match s.dtype() {
+        let dst = match s.dtype() {
             DType::BF16 => {
                 let s = s.as_cuda_slice::<bf16>()?;
-                // let s = match l.contiguous_offsets() {
-                // None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-                // Some((o1, o2)) => s.slice(o1..o2),
-                // };
+                let s = match l.contiguous_offsets() {
+                    Some((0, l)) if l == s.len() => s,
+                    Some(_) | None => candle::bail!("input has to be contiguous"),
+                };
                 let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
-                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-                Ok((dst, l.shape().clone()))
+                self.comm
+                    .all_reduce(s, &mut dst, &ReduceOp::Sum)
+                    .map_err(candle::Error::debug)?;
+                candle::CudaStorage::wrap_cuda_slice(dst, dev)
             }
             DType::F16 => {
                 let s = s.as_cuda_slice::<f16>()?;
-                // let s = match l.contiguous_offsets() {
-                // None => Err(Error::Wrapped("input has to be contiguous".into()))?,
-                // Some((o1, o2)) => s.slice(o1..o2),
-                // };
+                let s = match l.contiguous_offsets() {
+                    Some((0, l)) if l == s.len() => s,
+                    Some(_) | None => candle::bail!("input has to be contiguous"),
+                };
                 let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-                self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
-                let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
-                Ok((dst, l.shape().clone()))
+                self.comm
+                    .all_reduce(s, &mut dst, &ReduceOp::Sum)
+                    .map_err(candle::Error::debug)?;
+                candle::CudaStorage::wrap_cuda_slice(dst, dev)
             }
             dtype => candle::bail!("unsupported dtype {dtype:?}"),
-        }
+        };
+        Ok((dst, l.shape().clone()))
     }
 }
 
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
-    x.apply_op1(AllReduce { comm: comm.clone() })
-}
-
 impl TensorParallelRowLinear {
     fn new(linear: Linear, comm: Rc<Comm>) -> Self {
-        Self { linear, comm }
+        let all_reduce = AllReduce { comm };
+        Self { linear, all_reduce }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = self.linear.forward(x)?;
-        all_reduce_sum(&x, &self.comm)
+        self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
     }
 }
 
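For readers unfamiliar with candle custom ops, the following is a rough CPU-only sketch of the `CustomOp1` + `apply_op1_no_bwd` pattern that `TensorParallelRowLinear::forward` now relies on, with a toy negation op standing in for the NCCL all-reduce. It is illustrative only and not part of the commit; the crate is referenced by its in-tree name `candle`.

```rust
use candle::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

struct Negate;

impl CustomOp1 for Negate {
    fn name(&self) -> &'static str {
        "negate"
    }

    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
        // Only handle f32 here; the real all-reduce op handles bf16/f16.
        let vs = match s {
            CpuStorage::F32(vs) => vs,
            _ => candle::bail!("negate sketch only supports f32"),
        };
        // Same contiguity requirement as the all-reduce above.
        let data = match l.contiguous_offsets() {
            Some((o1, o2)) => vs[o1..o2].iter().map(|v| -v).collect::<Vec<_>>(),
            None => candle::bail!("input has to be contiguous"),
        };
        Ok((CpuStorage::F32(data), l.shape().clone()))
    }
}

fn negate(x: &Tensor) -> Result<Tensor> {
    // apply_op1_no_bwd runs the op without registering a backward pass,
    // matching how the all-reduce is applied in the row-parallel linear.
    x.apply_op1_no_bwd(&Negate)
}

fn main() -> Result<()> {
    let x = Tensor::new(&[1f32, -2.0, 3.0], &Device::Cpu)?;
    println!("{}", negate(&x)?);
    Ok(())
}
```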
@@ -137,23 +136,6 @@ impl TensorParallelRowLinear {
     }
 }
 
-#[derive(Deserialize)]
-pub struct Config {
-    pub hidden_size: usize,
-    pub intermediate_size: usize,
-    pub vocab_size: usize,
-    pub num_hidden_layers: usize,
-    pub num_attention_heads: usize,
-    pub num_key_value_heads: usize,
-    pub rms_norm_eps: f64,
-    #[serde(default = "default_rope")]
-    pub rope_theta: f32,
-}
-
-fn default_rope() -> f32 {
-    10_000.0
-}
-
 #[derive(Clone)]
 pub struct Cache {
     #[allow(clippy::type_complexity)]
@@ -281,9 +263,7 @@ impl CausalSelfAttention {
             let v = v.transpose(1, 2)?;
             let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
             let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
-                .transpose(1, 2)?;
-            // Convert to contiguous as matmul doesn't support strided vs for now.
-            let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
+                .reshape((b_sz, seq_len, hidden_size))?;
             let y = self.o_proj.forward(&y)?;
             Ok(y)
         }
@@ -304,7 +284,7 @@ impl CausalSelfAttention {
             qkv_proj,
             o_proj,
             num_attention_heads: cfg.num_attention_heads / comm.world_size(),
-            num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
+            num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
             head_dim: cfg.hidden_size / cfg.num_attention_heads,
             cache: cache.clone(),
         })
@@ -318,18 +298,6 @@ struct Mlp {
 }
 
 impl Mlp {
-    fn new(
-        c_fc1: TensorParallelColumnLinear,
-        c_fc2: TensorParallelColumnLinear,
-        c_proj: TensorParallelRowLinear,
-    ) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
@@ -339,7 +307,11 @@ impl Mlp {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
+        Ok(Self {
+            c_fc1,
+            c_fc2,
+            c_proj,
+        })
     }
 }
 
@@ -430,10 +402,8 @@ impl Llama {
                     cfg,
                     comm.clone(),
                 )
-                .unwrap()
             })
-            .collect();
-
+            .collect::<Result<Vec<_>>>()?;
         Ok(Self::new(wte, blocks, norm, lm_head))
     }
 }
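The `.collect::<Result<Vec<_>>>()?` change above replaces per-layer `.unwrap()` calls with error propagation; here is a minimal standalone sketch of that idiom (function names and limits are made up).

```rust
fn load_one(i: usize) -> candle::Result<u32> {
    if i < 8 {
        Ok(i as u32)
    } else {
        candle::bail!("no layer {i}")
    }
}

fn load_all(n: usize) -> candle::Result<Vec<u32>> {
    // An iterator of Result<T, E> collects into Result<Vec<T>, E>:
    // the first Err short-circuits and is returned by `?` at the call site.
    (0..n).map(load_one).collect::<candle::Result<Vec<_>>>()
}

fn main() {
    assert!(load_all(8).is_ok());
    assert!(load_all(9).is_err());
}
```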
@@ -20,6 +20,12 @@ pub struct LlamaConfig {
     pub eos_token_id: Option<u32>,
 }
 
+impl LlamaConfig {
+    pub fn num_key_value_heads(&self) -> usize {
+        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
+    }
+}
+
 fn default_rope() -> f32 {
     10_000.0
 }
@@ -32,7 +38,7 @@ impl LlamaConfig {
             vocab_size: self.vocab_size,
             num_hidden_layers: self.num_hidden_layers,
             num_attention_heads: self.num_attention_heads,
-            num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
+            num_key_value_heads: self.num_key_value_heads(),
             rms_norm_eps: self.rms_norm_eps,
             rope_theta: self.rope_theta,
             use_flash_attn,
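Finally, a small standalone illustration of what the new accessor centralizes: the `Option` fallback from key/value heads to attention heads now lives in one place for both MHA-style and GQA-style configs. The struct below is a stripped-down stand-in and the numbers are illustrative.

```rust
struct LlamaConfig {
    num_attention_heads: usize,
    num_key_value_heads: Option<usize>,
}

impl LlamaConfig {
    fn num_key_value_heads(&self) -> usize {
        // Same fallback as the real accessor: default to plain multi-head
        // attention when the config json omits an explicit kv-head count.
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }
}

fn main() {
    let mha = LlamaConfig { num_attention_heads: 32, num_key_value_heads: None };
    let gqa = LlamaConfig { num_attention_heads: 32, num_key_value_heads: Some(8) };
    assert_eq!(mha.num_key_value_heads(), 32);
    assert_eq!(gqa.num_key_value_heads(), 8);
    println!("mha kv heads: {}, gqa kv heads: {}", mha.num_key_value_heads(), gqa.num_key_value_heads());
}
```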