TP sharding v2

This commit is contained in:
Nicolas Patry
2023-07-21 15:10:51 +00:00
parent 209f06d7c3
commit 1735e4831e
9 changed files with 833 additions and 18 deletions

View File

@ -19,8 +19,10 @@ byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }
# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes # Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
# cudarc = { version = "0.9.13", optional = true, features = ["f16"] } # cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] } cudarc = { git = "https://github.com/coreylowman/cudarc.git", features = ["f16", "nccl"] }
# TODO: Switch back to the official gemm implementation if we manage to upstream the changes. # TODO: Switch back to the official gemm implementation once the following are available.
# https://github.com/sarah-ek/gemm/pull/8.
# https://github.com/sarah-ek/gemm/pull/9.
gemm = { git = "https://github.com/LaurentMazare/gemm.git" } gemm = { git = "https://github.com/LaurentMazare/gemm.git" }
hf-hub = "0.1.3" hf-hub = "0.1.3"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] } half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }

View File

@ -1,6 +1,7 @@
use crate::{DType, Device, Error, Result, Tensor, WithDType}; use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::slice::SliceIterator;
use safetensors::tensor as st; use safetensors::tensor as st;
pub use safetensors::tensor::SafeTensors; use safetensors::tensor::{Dtype, SafeTensors};
use std::borrow::Cow; use std::borrow::Cow;
impl From<DType> for st::Dtype { impl From<DType> for st::Dtype {
@ -63,15 +64,15 @@ impl Tensor {
} }
} }
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> { fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
let v = view.data();
let size_in_bytes = T::DTYPE.size_in_bytes(); let size_in_bytes = T::DTYPE.size_in_bytes();
let elem_count = v.len() / size_in_bytes; let elem_count = data.len() / size_in_bytes;
if (v.as_ptr() as usize) % size_in_bytes == 0 { if (data.as_ptr() as usize) % size_in_bytes == 0 {
// SAFETY This is safe because we just checked that this // SAFETY This is safe because we just checked that this
// was correctly aligned. // was correctly aligned.
let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) }; let data: &[T] =
Tensor::from_slice(data, view.shape(), device) unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
Tensor::from_slice(data, shape, device)
} else { } else {
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
@ -81,13 +82,17 @@ fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<
// We're downgrading the `c` pointer from T to u8, which removes alignment // We're downgrading the `c` pointer from T to u8, which removes alignment
// constraints. // constraints.
unsafe { unsafe {
std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len()); std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
c.set_len(elem_count) c.set_len(elem_count)
} }
Tensor::from_slice(&c, view.shape(), device) Tensor::from_slice(&c, shape, device)
} }
} }
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_slice::<T>(view.data(), view.shape(), device)
}
fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> { fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
let size_in_bytes = T::DTYPE.size_in_bytes(); let size_in_bytes = T::DTYPE.size_in_bytes();
let length = vs.len() * size_in_bytes; let length = vs.len() * size_in_bytes;
@ -112,6 +117,26 @@ impl<'a> Load for st::TensorView<'a> {
} }
} }
impl Tensor {
pub fn from_safetensors_slice(
iterator: SliceIterator,
dtype: Dtype,
shape: &[usize],
device: &Device,
) -> Result<Self> {
let data: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
match dtype {
st::Dtype::U8 => convert_slice::<u8>(&data, shape, device),
st::Dtype::U32 => convert_slice::<u8>(&data, shape, device),
st::Dtype::BF16 => convert_slice::<half::bf16>(&data, shape, device),
st::Dtype::F16 => convert_slice::<half::f16>(&data, shape, device),
st::Dtype::F32 => convert_slice::<f32>(&data, shape, device),
st::Dtype::F64 => convert_slice::<f64>(&data, shape, device),
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
}
pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> { pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() { match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device), st::Dtype::U8 => convert_::<u8>(view, device),

View File

@ -19,6 +19,8 @@ serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
num-traits = { workspace = true } num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -40,3 +42,8 @@ default = []
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"] flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["dep:cudarc", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl"]

View File

@ -0,0 +1,266 @@
// An implementation of LLaMA https://github.com/facebookresearch/llama
//
// This is based on nanoGPT in a similar way to:
// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
//
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use std::rc::Rc;
mod model;
use model::{Config, Llama};
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?
RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!
EDWARD:
Dazzle mine eyes, or do I see three suns?
RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.
EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
num_shards: usize,
#[arg(long)]
rank: Option<usize>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
sample_len: usize,
/// Disable the key-value cache.
#[arg(long)]
no_kv_cache: bool,
/// The initial prompt.
#[arg(long)]
prompt: Option<String>,
/// Use f32 computations rather than f16.
#[arg(long)]
use_f32: bool,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
v2: bool,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
let config = Config::config_7b();
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| {
if args.v2 {
"meta-llama/Llama-2-7b-hf".to_string()
} else {
"Narsil/amall-7b".to_string()
}
});
println!("loading the model weights from {model_id}");
let repo = Repo::new(model_id, RepoType::Model);
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(&repo, rfilename)?;
filenames.push(filename);
}
if args.rank.is_none() {
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait().unwrap();
}
return Ok(());
}
let i = args.rank.unwrap();
let num_shards = args.num_shards;
let rank = i;
// Primitive IPC
let id = if rank == 0 {
let id = Id::new().unwrap();
std::fs::File::create("nccl_id.txt.tmp")?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
.unwrap();
std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
id
} else {
let path = std::path::PathBuf::from("nccl_id.txt");
while !path.exists() {
std::thread::sleep(std::time::Duration::from_secs(1));
}
let data = std::fs::read("nccl_id.txt")?;
let internal: [i8; 128] = data
.into_iter()
.map(|i| i as i8)
.collect::<Vec<_>>()
.try_into()
.unwrap();
let id: Id = Id::uninit(internal);
id
};
let device = CudaDevice::new(i)?;
let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
if rank == 0 {
std::fs::remove_file("nccl_id.txt")?;
}
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
println!("building the model");
let handles = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
let llama = Llama::load(vb, &cache, &config, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
let context_size = if cache.use_kv_cache && index > 0 {
1
} else {
tokens.len()
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
);
}
let dt = start_gen.elapsed();
println!(
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
Ok(())
}

View File

@ -0,0 +1,465 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
use cudarc::driver::safe::CudaSlice;
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
struct TensorParallelColumnLinear {
linear: Linear,
}
impl TensorParallelColumnLinear {
fn new(linear: Linear) -> Self {
Self { linear }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.linear.forward(x)
}
}
struct TensorParallelRowLinear {
linear: Linear,
comm: Rc<Comm>,
}
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
Ok(x.clone())
// let n = x.shape().elem_count();
// let cuda_slice: CudaSlice<f16> = x.try_into()?;
// let dev = cuda_slice.device();
// let mut slice_receive = dev.alloc_zeros(n).unwrap();
// comm.all_reduce(cuda_slice, &mut slice_receive, &ReduceOp::Sum).unwrap();
// Tensor::from_raw_storage(slice_receive, x.shape())
}
impl TensorParallelRowLinear {
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
Self { linear, comm }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.linear.forward(x)?;
all_reduce_sum(&x, &self.comm)
}
}
impl TensorParallelColumnLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 0, rank, size)?;
Ok(Self::new(Linear::new(weight, None)))
}
fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weights: Vec<_> = prefixes
.iter()
.map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
.collect();
let weight = Tensor::cat(&weights, 0)?;
Ok(Self::new(Linear::new(weight, None)))
}
}
impl TensorParallelRowLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 1, rank, size)?;
Ok(Self::new(Linear::new(weight, None), comm.clone()))
}
}
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub n_layer: usize,
pub n_head: usize,
pub n_embd: usize,
pub n_key_value_head: usize,
}
impl Config {
pub fn config_7b() -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
n_layer: 32,
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
}
}
}
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
device: Device,
}
impl Cache {
pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
device: device.clone(),
cos,
sin,
})
}
fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
// TODO: If we support bool or u8 tensors, this would be better.
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
Ok(mask)
}
}
}
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size))
}
struct RmsNorm {
scale: Tensor,
}
impl RmsNorm {
fn load(size: usize, vb: VarBuilder) -> Result<Self> {
let scale = vb.get(size, "weight")?;
Ok(Self::new(scale))
}
fn new(scale: Tensor) -> Self {
Self { scale }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
let size = self.scale.shape().r1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((b_sz, seq_len, size))?;
let x = (scale * x_normed)?;
let x = x.to_dtype(in_dtype)?;
Ok(x)
}
}
struct CausalSelfAttention {
qkv_proj: TensorParallelColumnLinear,
o_proj: TensorParallelRowLinear,
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let x_dtype = x.dtype();
let (b_sz, seq_len, _) = x.shape().r3()?;
let qkv = self.qkv_proj.forward(x)?;
let n_embd = self.n_head * self.head_dim;
let q = qkv.i((.., .., ..self.n_head * self.head_dim))?;
let k = qkv.i((
..,
..,
self.n_head * self.head_dim
..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim,
))?;
let v = qkv.i((
..,
..,
self.n_head * self.head_dim + self.n_key_value_head * self.head_dim..,
))?;
// todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?
.to_dtype(DType::F32)?;
let k = k
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?
.to_dtype(DType::F32)?;
let mut v = v
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?
.to_dtype(DType::F32)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1];
if k_seq_len > MAX_SEQ_LEN {
k = k
.narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.contiguous()?
}
let v_seq_len = v.dims()[1];
if v_seq_len > 2 * MAX_SEQ_LEN {
v = v
.narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.contiguous()?
}
}
cache[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = y.to_dtype(x_dtype)?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_key_value_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
Ok(x)
}
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let size_in = cfg.hidden_size;
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
let qkv_proj = TensorParallelColumnLinear::load_multi(
vb.clone(),
&["q_proj", "k_proj", "v_proj"],
comm.clone(),
)?;
let o_proj = TensorParallelRowLinear::load(vb.pp("o_proj"), comm.clone())?;
Ok(Self {
qkv_proj,
o_proj,
n_head: cfg.n_head / comm.world_size(),
n_key_value_head: cfg.n_key_value_head / comm.world_size(),
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
})
}
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
struct Mlp {
c_fc1: TensorParallelColumnLinear,
c_fc2: TensorParallelColumnLinear,
c_proj: TensorParallelRowLinear,
}
impl Mlp {
fn new(
c_fc1: TensorParallelColumnLinear,
c_fc2: TensorParallelColumnLinear,
c_proj: TensorParallelRowLinear,
) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let h_size = cfg.hidden_size;
let i_size = cfg.intermediate_size;
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
}
}
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
rms_2,
mlp,
}
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
post_attention_layernorm,
mlp,
))
}
}
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
}
impl Llama {
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
Self {
wte,
blocks,
ln_f,
lm_head,
}
}
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.shape().r2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layer)
.map(|i| {
Block::load(
vb.pp(&format!("model.layers.{i}")),
cache,
cfg,
comm.clone(),
)
.unwrap()
})
.collect();
Ok(Self::new(wte, blocks, norm, lm_head))
}
}

View File

@ -14,6 +14,7 @@ readme = "README.md"
candle = { path = "../candle-core" } candle = { path = "../candle-core" }
thiserror = { workspace = true } thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
safetensors = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }

View File

@ -1,7 +1,6 @@
use candle::{ use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
safetensors::{Load, SafeTensors}, use safetensors::slice::IndexOp;
DType, Device, Error, Result, Shape, Tensor, use safetensors::tensor::SafeTensors;
};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@ -71,7 +70,7 @@ impl<'a> TensorData<'a> {
#[derive(Clone)] #[derive(Clone)]
pub struct VarBuilder<'a> { pub struct VarBuilder<'a> {
data: Arc<TensorData<'a>>, data: Arc<TensorData<'a>>,
path: Vec<String>, pub path: Vec<String>,
} }
impl<'a> VarBuilder<'a> { impl<'a> VarBuilder<'a> {
@ -137,6 +136,55 @@ impl<'a> VarBuilder<'a> {
} }
impl<'a> VarBuilder<'a> { impl<'a> VarBuilder<'a> {
pub fn get_sharded(
&self,
tensor_name: &str,
dim: usize,
rank: usize,
world_size: usize,
) -> Result<Tensor> {
let data = self.data.as_ref();
let path = if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
};
let tensor = match &self.data.tensors {
Tensors::SafeTensorWithRouting {
routing,
safetensors,
} => {
let index = routing.get(&path).ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?;
let view = safetensors[*index].tensor(&path)?;
let dtype = view.dtype();
let mut shape = view.shape().to_vec();
let size = shape[dim];
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
let iterator = if dim == 0 {
view.slice(start..stop).unwrap()
} else if dim == 1 {
view.slice((.., start..stop)).unwrap()
} else {
unimplemented!("Get sharded on dimensions != 0 or 1");
};
shape[dim] = block_size;
Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
}
_ => unimplemented!(),
};
Ok(tensor)
}
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> { pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
let data = self.data.as_ref(); let data = self.data.as_ref();
let s: Shape = s.into(); let s: Shape = s.into();

View File

@ -15,6 +15,7 @@ candle = { path = "../../candle-core" }
candle-nn = { path = "../../candle-nn" } candle-nn = { path = "../../candle-nn" }
num-traits = { workspace = true } num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] } tokenizers = { workspace = true, features = ["unstable_wasm"] }
safetensors = { workspace = true }
# App crates. # App crates.
anyhow = { workspace = true } anyhow = { workspace = true }

View File

@ -236,11 +236,11 @@ impl Decoder {
let device = Device::Cpu; let device = Device::Cpu;
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?; let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?; let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?; let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
console_log!("loaded mel filters {:?}", mel_filters.shape()); console_log!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?; let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?; let weights = safetensors::tensor::SafeTensors::deserialize(&md.weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let config = Config::tiny_en(); let config = Config::tiny_en();
let whisper = Whisper::load(&vb, config)?; let whisper = Whisper::load(&vb, config)?;