From 872c3f14b0c5ead35d031a9881c5944700508756 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 17 Oct 2023 16:06:48 +0100
Subject: [PATCH] Add the MPT model. (#1114)

* Add the MPT model.

* Add ffn and block.

* Forward pass for the mpt block.

* Repeat-kv.
---
 candle-transformers/src/models/mod.rs |   1 +
 candle-transformers/src/models/mpt.rs | 202 ++++++++++++++++++++++++++
 2 files changed, 203 insertions(+)
 create mode 100644 candle-transformers/src/models/mpt.rs

diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index aa9ea81a..8a02e2da 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -7,6 +7,7 @@ pub mod falcon;
 pub mod llama;
 pub mod mistral;
 pub mod mixformer;
+pub mod mpt;
 pub mod quantized_llama;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
new file mode 100644
index 00000000..e11a9a75
--- /dev/null
+++ b/candle-transformers/src/models/mpt.rs
@@ -0,0 +1,202 @@
+#![allow(unused)]
+use crate::models::with_tracing::{linear, Embedding as E, Linear};
+/// MPT model used by replit-code-v1_5-3b
+/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder};
+
+// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/configuration_mpt.py
+#[derive(Debug, Clone, PartialEq)]
+pub struct Config {
+    pub(crate) d_model: usize,
+    pub(crate) n_heads: usize,
+    pub(crate) n_layers: usize,
+    pub(crate) expansion_ratio: usize,
+    pub(crate) max_seq_len: usize,
+    pub(crate) vocab_size: usize,
+    pub(crate) kv_n_heads: usize,
+    // pub(crate) attn_config: AttnConfig,
+}
+
+impl Config {
+    pub fn replit_code_v1_5_3b() -> Self {
+        Self {
+            d_model: 3072,
+            n_heads: 24,
+            n_layers: 32,
+            expansion_ratio: 4,
+            max_seq_len: 4096,
+            vocab_size: 32768,
+            kv_n_heads: 8,
+        }
+    }
+}
+
+#[derive(Debug)]
+struct GroupedQueryAttention {
+    wqkv: Linear,
+    out_proj: Linear,
+    kv_cache: Option<(Tensor, Tensor)>,
+    softmax_scale: f64,
+    head_dim: usize,
+    d_model: usize,
+    n_heads: usize,
+    kv_n_heads: usize,
+    span: tracing::Span,
+}
+
+impl GroupedQueryAttention {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let head_dim = cfg.d_model / cfg.n_heads;
+        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
+        let wqkv = linear(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
+        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
+        let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+        Ok(Self {
+            wqkv,
+            out_proj,
+            kv_cache: None,
+            softmax_scale,
+            head_dim,
+            d_model: cfg.d_model,
+            n_heads: cfg.n_heads,
+            kv_n_heads: cfg.kv_n_heads,
+            span: tracing::span!(tracing::Level::TRACE, "gqa"),
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (b_size, seq_len, _n_embd) = xs.dims3()?;
+        let qkv = self.wqkv.forward(xs)?;
+        let query = qkv.narrow(2, 0, self.d_model)?;
+        let kv_size = self.kv_n_heads * self.head_dim;
+        let key = qkv.narrow(2, self.d_model, kv_size)?;
+        let value = qkv.narrow(2, self.d_model + kv_size, kv_size)?;
+        // scaled_multihead_dot_product_attention
+        let query = query
+            .reshape((b_size, seq_len, self.n_heads, ()))?
+            .transpose(1, 2)?; // b,h,s,d
+        let key = key
+            .reshape((b_size, seq_len, self.kv_n_heads, ()))?
+            .permute((0, 2, 3, 1))?; // b,h,d,s
+        let value = value
+            .reshape((b_size, seq_len, self.kv_n_heads, ()))?
+            .transpose(1, 2)?; // b,h,s,d
+        let (key, value) = match &self.kv_cache {
+            None => (key, value),
+            Some((prev_k, prev_v)) => {
+                let k = Tensor::cat(&[prev_k, &key], 3)?;
+                let v = Tensor::cat(&[prev_v, &value], 2)?;
+                (k, v)
+            }
+        };
+        let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?;
+        let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?;
+        let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
+        // TODO: attn_bias, alibi
+        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_output = attn_weights
+            .matmul(&value)?
+            .transpose(1, 2)?
+            .flatten_from(D::Minus2)?;
+        attn_output.apply(&self.out_proj)
+    }
+}
+
+// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
+// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
+// (batch, num_attention_heads, seqlen, head_dim)
+fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
+    if n_rep == 1 {
+        Ok(xs)
+    } else {
+        let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
+        xs.unsqueeze(2)?
+            .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
+            .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
+    }
+}
+
+#[derive(Debug)]
+struct Ffn {
+    up_proj: Linear,
+    down_proj: Linear,
+}
+
+impl Ffn {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden = cfg.d_model * cfg.expansion_ratio;
+        let down_proj = linear(hidden, cfg.d_model, vb.pp("down_proj"))?;
+        let up_proj = linear(cfg.d_model, hidden, vb.pp("up_proj"))?;
+        Ok(Self { up_proj, down_proj })
+    }
+}
+
+impl Module for Ffn {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.up_proj)?.gelu_erf()?.apply(&self.down_proj)
+    }
+}
+
+#[derive(Debug)]
+struct MPTBlock {
+    norm1: LayerNorm, // Do we need the low-precision variant?
+    attn: GroupedQueryAttention,
+    norm2: LayerNorm,
+    ffn: Ffn,
+}
+
+impl MPTBlock {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let norm1 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
+        let norm2 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
+        let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
+        let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
+        Ok(Self {
+            norm1,
+            attn,
+            norm2,
+            ffn,
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+        let residual = xs;
+        let xs = xs.apply(&self.norm1)?;
+        let xs = self.attn.forward(&xs, mask)?;
+        let xs = (xs + residual)?;
+        let residual = &xs;
+        let xs = xs.apply(&self.norm2)?.apply(&self.ffn);
+        xs + residual
+    }
+}
+
+#[derive(Debug)]
+struct Model {
+    wte: candle_nn::Embedding,
+    blocks: Vec<MPTBlock>,
+    norm_f: LayerNorm,
+}
+
+impl Model {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let wte = candle_nn::embedding(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
+        let vb_b = vb.pp("blocks");
+        let mut blocks = Vec::with_capacity(cfg.n_layers);
+        for i in 0..cfg.n_layers {
+            let block = MPTBlock::new(cfg, vb_b.pp(i))?;
+            blocks.push(block)
+        }
+        let norm_f = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+        Ok(Self {
+            wte,
+            blocks,
+            norm_f,
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+        todo!()
+    }
+}
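
Reviewer note, not part of the commit: the snippet below is a minimal, self-contained sketch of the grouped-query shape bookkeeping that GroupedQueryAttention and repeat_kv rely on, using the replit_code_v1_5_3b() numbers (24 query heads sharing 8 key/value heads, head_dim 128). It assumes the same `candle` alias for candle-core that candle-transformers uses, and it works on dummy zero tensors, so only the shapes are meaningful.

// Illustrative sketch only; mirrors the unsqueeze/expand/reshape trick in repeat_kv.
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // d_model = 3072 split over 24 heads gives head_dim = 128; 8 KV heads are shared.
    let (b, s, n_heads, kv_n_heads, head_dim) = (1usize, 7, 24, 8, 128);
    // A key (or value) tensor in the (batch, kv_heads, seq, head_dim) layout.
    let kv = Tensor::zeros((b, kv_n_heads, s, head_dim), DType::F32, &dev)?;
    // Each KV head is repeated n_rep times so every query head has a counterpart.
    let n_rep = n_heads / kv_n_heads; // 24 / 8 = 3
    let repeated = kv
        .unsqueeze(2)?
        .expand((b, kv_n_heads, n_rep, s, head_dim))?
        .reshape((b, kv_n_heads * n_rep, s, head_dim))?;
    assert_eq!(repeated.dims4()?, (b, n_heads, s, head_dim));
    Ok(())
}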