mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
TP sharding v2
This commit is contained in:
@ -19,6 +19,8 @@ serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -40,3 +42,8 @@ default = []
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["dep:cudarc", "dep:half"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
required-features = ["cuda", "nccl"]
|
||||
|
266
candle-examples/examples/llama_multiprocess/main.rs
Normal file
266
candle-examples/examples/llama_multiprocess/main.rs
Normal file
@ -0,0 +1,266 @@
|
||||
// An implementation of LLaMA https://github.com/facebookresearch/llama
|
||||
//
|
||||
// This is based on nanoGPT in a similar way to:
|
||||
// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
|
||||
//
|
||||
// The tokenizer config can be retrieved from:
|
||||
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
|
||||
//
|
||||
// In order to convert the llama weights to a .npz file, run:
|
||||
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use cudarc::driver::safe::CudaDevice;
|
||||
use cudarc::nccl::safe::{Comm, Id};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
use std::rc::Rc;
|
||||
|
||||
mod model;
|
||||
use model::{Config, Llama};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_PROMPT: &str = r"
|
||||
EDWARD:
|
||||
I wonder how our princely father 'scaped,
|
||||
Or whether he be 'scaped away or no
|
||||
From Clifford's and Northumberland's pursuit:
|
||||
Had he been ta'en, we should have heard the news;
|
||||
Had he been slain, we should have heard the news;
|
||||
Or had he 'scaped, methinks we should have heard
|
||||
The happy tidings of his good escape.
|
||||
How fares my brother? why is he so sad?
|
||||
|
||||
RICHARD:
|
||||
I cannot joy, until I be resolved
|
||||
Where our right valiant father is become.
|
||||
I saw him in the battle range about;
|
||||
And watch'd him how he singled Clifford forth.
|
||||
Methought he bore him in the thickest troop
|
||||
As doth a lion in a herd of neat;
|
||||
Or as a bear, encompass'd round with dogs,
|
||||
Who having pinch'd a few and made them cry,
|
||||
The rest stand all aloof, and bark at him.
|
||||
So fared our father with his enemies;
|
||||
So fled his enemies my warlike father:
|
||||
Methinks, 'tis prize enough to be his son.
|
||||
See how the morning opes her golden gates,
|
||||
And takes her farewell of the glorious sun!
|
||||
How well resembles it the prime of youth,
|
||||
Trimm'd like a younker prancing to his love!
|
||||
|
||||
EDWARD:
|
||||
Dazzle mine eyes, or do I see three suns?
|
||||
|
||||
RICHARD:
|
||||
Three glorious suns, each one a perfect sun;
|
||||
Not separated with the racking clouds,
|
||||
But sever'd in a pale clear-shining sky.
|
||||
See, see! they join, embrace, and seem to kiss,
|
||||
As if they vow'd some league inviolable:
|
||||
Now are they but one lamp, one light, one sun.
|
||||
In this the heaven figures some event.
|
||||
|
||||
EDWARD:
|
||||
'Tis wondrous strange, the like yet never heard of.
|
||||
I think it cites us, brother, to the field,
|
||||
That we, the sons of brave Plantagenet,
|
||||
Each one already blazing by our meeds,
|
||||
Should notwithstanding join our lights together
|
||||
And over-shine the earth as this the world.
|
||||
Whate'er it bodes, henceforward will I bear
|
||||
Upon my target three fair-shining suns.
|
||||
";
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long)]
|
||||
num_shards: usize,
|
||||
|
||||
#[arg(long)]
|
||||
rank: Option<usize>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
/// Disable the key-value cache.
|
||||
#[arg(long)]
|
||||
no_kv_cache: bool,
|
||||
|
||||
/// The initial prompt.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// Use f32 computations rather than f16.
|
||||
#[arg(long)]
|
||||
use_f32: bool,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
v2: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let config = Config::config_7b();
|
||||
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
||||
|
||||
let api = Api::new()?;
|
||||
|
||||
let model_id = args.model_id.unwrap_or_else(|| {
|
||||
if args.v2 {
|
||||
"meta-llama/Llama-2-7b-hf".to_string()
|
||||
} else {
|
||||
"Narsil/amall-7b".to_string()
|
||||
}
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let repo = Repo::new(model_id, RepoType::Model);
|
||||
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
"model-00001-of-00002.safetensors",
|
||||
"model-00002-of-00002.safetensors",
|
||||
] {
|
||||
let filename = api.get(&repo, rfilename)?;
|
||||
filenames.push(filename);
|
||||
}
|
||||
|
||||
if args.rank.is_none() {
|
||||
let children: Vec<_> = (0..args.num_shards)
|
||||
.map(|rank| {
|
||||
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
|
||||
args.push_back("--rank".to_string());
|
||||
args.push_back(format!("{rank}"));
|
||||
let name = args.pop_front().unwrap();
|
||||
std::process::Command::new(name).args(args).spawn().unwrap()
|
||||
})
|
||||
.collect();
|
||||
for mut child in children {
|
||||
child.wait().unwrap();
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let i = args.rank.unwrap();
|
||||
let num_shards = args.num_shards;
|
||||
let rank = i;
|
||||
// Primitive IPC
|
||||
let id = if rank == 0 {
|
||||
let id = Id::new().unwrap();
|
||||
std::fs::File::create("nccl_id.txt.tmp")?
|
||||
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
|
||||
.unwrap();
|
||||
std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
|
||||
id
|
||||
} else {
|
||||
let path = std::path::PathBuf::from("nccl_id.txt");
|
||||
while !path.exists() {
|
||||
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
}
|
||||
let data = std::fs::read("nccl_id.txt")?;
|
||||
let internal: [i8; 128] = data
|
||||
.into_iter()
|
||||
.map(|i| i as i8)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
let id: Id = Id::uninit(internal);
|
||||
id
|
||||
};
|
||||
let device = CudaDevice::new(i)?;
|
||||
let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
|
||||
if rank == 0 {
|
||||
std::fs::remove_file("nccl_id.txt")?;
|
||||
}
|
||||
println!("Rank {rank:?} spawned");
|
||||
|
||||
let device = Device::new_cuda(i)?;
|
||||
let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
|
||||
|
||||
println!("building the model");
|
||||
let handles = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let tensors: Vec<_> = handles
|
||||
.iter()
|
||||
.map(|h| Ok(h.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
|
||||
let llama = Llama::load(vb, &cache, &config, comm)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||
let mut new_tokens = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
for index in 0..args.sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let context_size = if cache.use_kv_cache && index > 0 {
|
||||
1
|
||||
} else {
|
||||
tokens.len()
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = llama.forward(&input, index_pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
println!("> {:?}", start_gen.elapsed());
|
||||
println!(
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
next_token,
|
||||
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
|
||||
);
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"{} tokens generated ({} token/s)\n----\n{}\n----",
|
||||
args.sample_len,
|
||||
args.sample_len as f64 / dt.as_secs_f64(),
|
||||
tokenizer.decode(new_tokens, true).map_err(E::msg)?
|
||||
);
|
||||
Ok(())
|
||||
}
|
465
candle-examples/examples/llama_multiprocess/model.rs
Normal file
465
candle-examples/examples/llama_multiprocess/model.rs
Normal file
@ -0,0 +1,465 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Linear, VarBuilder};
|
||||
use cudarc::driver::safe::CudaSlice;
|
||||
use cudarc::nccl::safe::{Comm, ReduceOp};
|
||||
use half::f16;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::MAX_SEQ_LEN;
|
||||
|
||||
struct TensorParallelColumnLinear {
|
||||
linear: Linear,
|
||||
}
|
||||
|
||||
impl TensorParallelColumnLinear {
|
||||
fn new(linear: Linear) -> Self {
|
||||
Self { linear }
|
||||
}
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.linear.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
struct TensorParallelRowLinear {
|
||||
linear: Linear,
|
||||
comm: Rc<Comm>,
|
||||
}
|
||||
|
||||
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
|
||||
Ok(x.clone())
|
||||
// let n = x.shape().elem_count();
|
||||
// let cuda_slice: CudaSlice<f16> = x.try_into()?;
|
||||
// let dev = cuda_slice.device();
|
||||
// let mut slice_receive = dev.alloc_zeros(n).unwrap();
|
||||
// comm.all_reduce(cuda_slice, &mut slice_receive, &ReduceOp::Sum).unwrap();
|
||||
// Tensor::from_raw_storage(slice_receive, x.shape())
|
||||
}
|
||||
|
||||
impl TensorParallelRowLinear {
|
||||
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
|
||||
Self { linear, comm }
|
||||
}
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.linear.forward(x)?;
|
||||
all_reduce_sum(&x, &self.comm)
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorParallelColumnLinear {
|
||||
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
|
||||
let rank = comm.rank();
|
||||
let size = comm.world_size();
|
||||
let weight = vb.get_sharded("weight", 0, rank, size)?;
|
||||
Ok(Self::new(Linear::new(weight, None)))
|
||||
}
|
||||
|
||||
fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
|
||||
let rank = comm.rank();
|
||||
let size = comm.world_size();
|
||||
let weights: Vec<_> = prefixes
|
||||
.iter()
|
||||
.map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
|
||||
.collect();
|
||||
let weight = Tensor::cat(&weights, 0)?;
|
||||
Ok(Self::new(Linear::new(weight, None)))
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorParallelRowLinear {
|
||||
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
|
||||
let rank = comm.rank();
|
||||
let size = comm.world_size();
|
||||
let weight = vb.get_sharded("weight", 1, rank, size)?;
|
||||
Ok(Self::new(Linear::new(weight, None), comm.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub n_layer: usize,
|
||||
pub n_head: usize,
|
||||
pub n_embd: usize,
|
||||
pub n_key_value_head: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn config_7b() -> Self {
|
||||
Self {
|
||||
hidden_size: 4096,
|
||||
intermediate_size: 11008,
|
||||
vocab_size: 32000,
|
||||
n_layer: 32,
|
||||
n_head: 32,
|
||||
n_embd: 4096,
|
||||
n_key_value_head: 32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
pub use_kv_cache: bool,
|
||||
#[allow(clippy::type_complexity)]
|
||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
|
||||
// precompute freqs_cis
|
||||
let n_elem = config.n_embd / config.n_head;
|
||||
let theta: Vec<_> = (0..n_elem)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
// This is different from the paper, see:
|
||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok(Self {
|
||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||
use_kv_cache,
|
||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||
device: device.clone(),
|
||||
cos,
|
||||
sin,
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&self, t: usize) -> Result<Tensor> {
|
||||
let mut masks = self.masks.lock().unwrap();
|
||||
if let Some(mask) = masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
// TODO: If we support bool or u8 tensors, this would be better.
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
||||
masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
xs / (xs.neg()?.exp()? + 1.0)?
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
|
||||
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, cfg.hidden_size))
|
||||
}
|
||||
|
||||
struct RmsNorm {
|
||||
scale: Tensor,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn load(size: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let scale = vb.get(size, "weight")?;
|
||||
Ok(Self::new(scale))
|
||||
}
|
||||
|
||||
fn new(scale: Tensor) -> Self {
|
||||
Self { scale }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let in_dtype = x.dtype();
|
||||
// This is a no-op if x's dtype is already f32.
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
|
||||
let size = self.scale.shape().r1()?;
|
||||
let scale = self
|
||||
.scale
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_as((b_sz, seq_len, size))?;
|
||||
let x = (scale * x_normed)?;
|
||||
let x = x.to_dtype(in_dtype)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
struct CausalSelfAttention {
|
||||
qkv_proj: TensorParallelColumnLinear,
|
||||
o_proj: TensorParallelRowLinear,
|
||||
n_head: usize,
|
||||
n_key_value_head: usize,
|
||||
head_dim: usize,
|
||||
cache: Cache,
|
||||
}
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
|
||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
|
||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let (b_sz, seq_len, _) = x.shape().r3()?;
|
||||
|
||||
let qkv = self.qkv_proj.forward(x)?;
|
||||
let n_embd = self.n_head * self.head_dim;
|
||||
|
||||
let q = qkv.i((.., .., ..self.n_head * self.head_dim))?;
|
||||
let k = qkv.i((
|
||||
..,
|
||||
..,
|
||||
self.n_head * self.head_dim
|
||||
..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim,
|
||||
))?;
|
||||
let v = qkv.i((
|
||||
..,
|
||||
..,
|
||||
self.n_head * self.head_dim + self.n_key_value_head * self.head_dim..,
|
||||
))?;
|
||||
// todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.to_dtype(DType::F32)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
if self.cache.use_kv_cache {
|
||||
let mut cache = self.cache.kvs.lock().unwrap();
|
||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
|
||||
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
|
||||
let k_seq_len = k.dims()[1];
|
||||
if k_seq_len > MAX_SEQ_LEN {
|
||||
k = k
|
||||
.narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
||||
.contiguous()?
|
||||
}
|
||||
let v_seq_len = v.dims()[1];
|
||||
if v_seq_len > 2 * MAX_SEQ_LEN {
|
||||
v = v
|
||||
.narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
||||
.contiguous()?
|
||||
}
|
||||
}
|
||||
cache[block_idx] = Some((k.clone(), v.clone()))
|
||||
}
|
||||
|
||||
let k = self.repeat_kv(k)?;
|
||||
let v = self.repeat_kv(v)?;
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = att.softmax(D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = y.to_dtype(x_dtype)?;
|
||||
let y = self.o_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.n_head / self.n_key_value_head;
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
let size_in = cfg.hidden_size;
|
||||
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
|
||||
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
|
||||
|
||||
let qkv_proj = TensorParallelColumnLinear::load_multi(
|
||||
vb.clone(),
|
||||
&["q_proj", "k_proj", "v_proj"],
|
||||
comm.clone(),
|
||||
)?;
|
||||
let o_proj = TensorParallelRowLinear::load(vb.pp("o_proj"), comm.clone())?;
|
||||
Ok(Self {
|
||||
qkv_proj,
|
||||
o_proj,
|
||||
n_head: cfg.n_head / comm.world_size(),
|
||||
n_key_value_head: cfg.n_key_value_head / comm.world_size(),
|
||||
head_dim: cfg.hidden_size / cfg.n_head,
|
||||
cache: cache.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
struct Mlp {
|
||||
c_fc1: TensorParallelColumnLinear,
|
||||
c_fc2: TensorParallelColumnLinear,
|
||||
c_proj: TensorParallelRowLinear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(
|
||||
c_fc1: TensorParallelColumnLinear,
|
||||
c_fc2: TensorParallelColumnLinear,
|
||||
c_proj: TensorParallelRowLinear,
|
||||
) -> Self {
|
||||
Self {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
c_proj,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
self.c_proj.forward(&x)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
let h_size = cfg.hidden_size;
|
||||
let i_size = cfg.intermediate_size;
|
||||
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
|
||||
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
|
||||
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
|
||||
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
||||
}
|
||||
}
|
||||
|
||||
struct Block {
|
||||
rms_1: RmsNorm,
|
||||
attn: CausalSelfAttention,
|
||||
rms_2: RmsNorm,
|
||||
mlp: Mlp,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
|
||||
Self {
|
||||
rms_1,
|
||||
attn,
|
||||
rms_2,
|
||||
mlp,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
let residual = x;
|
||||
let x = self.rms_1.forward(x)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||
let residual = &x;
|
||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
|
||||
let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
|
||||
let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm =
|
||||
RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
|
||||
Ok(Self::new(
|
||||
input_layernorm,
|
||||
attn,
|
||||
post_attention_layernorm,
|
||||
mlp,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Llama {
|
||||
wte: Embedding,
|
||||
blocks: Vec<Block>,
|
||||
ln_f: RmsNorm,
|
||||
lm_head: Linear,
|
||||
}
|
||||
|
||||
impl Llama {
|
||||
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
|
||||
Self {
|
||||
wte,
|
||||
blocks,
|
||||
ln_f,
|
||||
lm_head,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.shape().r2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx)?;
|
||||
}
|
||||
let x = self.ln_f.forward(&x)?;
|
||||
let x = x.i((.., seq_len - 1, ..))?;
|
||||
let logits = self.lm_head.forward(&x)?;
|
||||
logits.to_dtype(DType::F32)
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.n_layer)
|
||||
.map(|i| {
|
||||
Block::load(
|
||||
vb.pp(&format!("model.layers.{i}")),
|
||||
cache,
|
||||
cfg,
|
||||
comm.clone(),
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Self::new(wte, blocks, norm, lm_head))
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user