diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 82c6c980..0f72b862 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
 
 use anyhow::{Error as E, Result};
 use clap::Parser;
-use candle_transformers::models::mpt::{Config, Model};
+use candle_transformers::models::mpt::{Config, Model as M};
+use candle_transformers::models::quantized_mpt::Model as Q;
 
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
@@ -15,6 +16,20 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+enum Model {
+    M(M),
+    Q(Q),
+}
+
+impl Model {
+    fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::M(model) => model.forward(xs),
+            Self::Q(model) => model.forward(xs),
+        }
+    }
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -148,6 +163,9 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
+    #[arg(long)]
+    quantized: bool,
+
     #[arg(long)]
     weight_file: Option<String>,
 
@@ -206,16 +224,29 @@ fn main() -> Result<()> {
     };
     let filename = match args.weight_file {
         Some(weight_file) => std::path::PathBuf::from(weight_file),
-        None => repo.get("model.safetensors")?,
+        None => {
+            if args.quantized {
+                repo.get("model-replit-code-v1_5-q4k.gguf")?
+            } else {
+                repo.get("model.safetensors")?
+            }
+        }
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
     let config = Config::replit_code_v1_5_3b();
-    let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-    let model = Model::new(&config, vb.pp("transformer"))?;
+    let (model, device) = if args.quantized {
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
+        let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
+        (model, Device::Cpu)
+    } else {
+        let device = candle_examples::device(args.cpu)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
+        let model = Model::M(M::new(&config, vb.pp("transformer"))?);
+        (model, device)
+    };
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 8a02e2da..fc57e732 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -11,6 +11,7 @@ pub mod mpt;
 pub mod quantized_llama;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
+pub mod quantized_mpt;
 pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod segment_anything;
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index 300a1d57..0d91bf94 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -137,7 +137,7 @@ impl GroupedQueryAttention {
 // This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
 // The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
 // (batch, num_attention_heads, seqlen, head_dim)
-fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
+pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
     if n_rep == 1 {
         Ok(xs)
     } else {
@@ -206,7 +206,7 @@ impl MPTBlock {
     }
 }
 
-fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
+pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
     let full = !cfg.is_causal();
     let seq_len = cfg.max_seq_len;
     let alibi_bias = Tensor::arange(1 - seq_len as i64, 1, &Device::Cpu)?;
@@ -289,14 +289,14 @@ impl Model {
     }
 }
 
-fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+pub(crate) fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
         .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
         .collect();
     Tensor::from_slice(&mask, (size, size), device)
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+pub(crate) fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
     let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
     let m = mask.where_cond(&on_true, on_false)?;
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
new file mode 100644
index 00000000..7586e4c0
--- /dev/null
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -0,0 +1,201 @@
+use crate::quantized_nn::{layer_norm_no_bias, linear_no_bias, Embedding, Linear};
+pub use crate::quantized_var_builder::VarBuilder;
+/// MPT model used by replit-code-v1_5-3b
+/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
+use candle::{IndexOp, Module, Result, Tensor, D};
+use candle_nn::LayerNorm;
+
+pub use super::mpt::Config;
+
+#[derive(Debug)]
+struct GroupedQueryAttention {
+    wqkv: Linear,
+    out_proj: Linear,
+    kv_cache: Option<(Tensor, Tensor)>,
+    softmax_scale: f64,
+    head_dim: usize,
+    d_model: usize,
+    n_heads: usize,
+    kv_n_heads: usize,
+    attn_bias: Tensor,
+    span: tracing::Span,
+}
+
+impl GroupedQueryAttention {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let head_dim = cfg.d_model / cfg.n_heads;
+        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
+        let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
+        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
+        let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+        let attn_bias = super::mpt::build_alibi_bias(cfg)?.to_device(vb.device())?;
+        Ok(Self {
+            wqkv,
+            out_proj,
+            kv_cache: None,
+            softmax_scale,
+            head_dim,
+            d_model: cfg.d_model,
+            n_heads: cfg.n_heads,
+            kv_n_heads: cfg.kv_n_heads,
+            attn_bias,
+            span: tracing::span!(tracing::Level::TRACE, "gqa"),
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (b_size, seq_len, _n_embd) = xs.dims3()?;
+        let qkv = self.wqkv.forward(xs)?;
+        let query = qkv.narrow(2, 0, self.d_model)?;
+        let kv_size = self.kv_n_heads * self.head_dim;
+        let key = qkv.narrow(2, self.d_model, kv_size)?;
+        let value = qkv.narrow(2, self.d_model + kv_size, kv_size)?;
+        // scaled_multihead_dot_product_attention
+        let query = query
+            .reshape((b_size, seq_len, self.n_heads, ()))?
+            .transpose(1, 2)?; // b,h,s,d
+        let key = key
+            .reshape((b_size, seq_len, self.kv_n_heads, ()))?
+            .permute((0, 2, 3, 1))?; // b,h,d,s
+        let value = value
+            .reshape((b_size, seq_len, self.kv_n_heads, ()))?
+            .transpose(1, 2)?; // b,h,s,d
+        let (key, value) = match &self.kv_cache {
+            None => (key, value),
+            Some((prev_k, prev_v)) => {
+                let k = Tensor::cat(&[prev_k, &key], 3)?;
+                let v = Tensor::cat(&[prev_v, &value], 2)?;
+                (k, v)
+            }
+        };
+        self.kv_cache = Some((key.clone(), value.clone()));
+        let query = query.contiguous()?;
+        let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
+        let attn_bias = {
+            let s_q = query.dim(D::Minus2)?;
+            let s_k = key.dim(D::Minus1)?;
+            let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
+            let start_q = a_q.saturating_sub(s_q);
+            let start_k = a_k.saturating_sub(s_k);
+            self.attn_bias.i((.., .., start_q.., start_k..))?
+        };
+        let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
+        let attn_weights = match mask {
+            None => attn_weights,
+            Some(mask) => super::mpt::masked_fill(
+                &attn_weights,
+                &mask.broadcast_as(attn_weights.shape())?,
+                f32::NEG_INFINITY,
+            )?,
+        };
+        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_output = attn_weights
+            .matmul(&value)?
+            .transpose(1, 2)?
+            .flatten_from(D::Minus2)?;
+        let out = attn_output.apply(&self.out_proj)?;
+        Ok(out)
+    }
+}
+
+#[derive(Debug)]
+struct Ffn {
+    up_proj: Linear,
+    down_proj: Linear,
+}
+
+impl Ffn {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden = cfg.d_model * cfg.expansion_ratio;
+        let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
+        let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
+        Ok(Self { up_proj, down_proj })
+    }
+}
+
+impl Module for Ffn {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.up_proj)?.gelu_erf()?.apply(&self.down_proj)
+    }
+}
+
+#[derive(Debug)]
+struct MPTBlock {
+    norm1: LayerNorm, // Do we need the low-precision variant?
+    attn: GroupedQueryAttention,
+    norm2: LayerNorm,
+    ffn: Ffn,
+}
+
+impl MPTBlock {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let norm1 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
+        let norm2 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
+        let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
+        let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
+        Ok(Self {
+            norm1,
+            attn,
+            norm2,
+            ffn,
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+        let residual = xs;
+        let xs = xs.apply(&self.norm1)?;
+        let xs = self.attn.forward(&xs, mask)?;
+        let xs = (xs + residual)?;
+        let residual = &xs;
+        let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
+        xs + residual
+    }
+}
+
+#[derive(Debug)]
+pub struct Model {
+    wte: Embedding,
+    blocks: Vec<MPTBlock>,
+    norm_f: LayerNorm,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let wte = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
+        let vb_b = vb.pp("blocks");
+        let mut blocks = Vec::with_capacity(cfg.n_layers);
+        for i in 0..cfg.n_layers {
+            let block = MPTBlock::new(cfg, vb_b.pp(i))?;
+            blocks.push(block)
+        }
+        let norm_f = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+        Ok(Self {
+            wte,
+            blocks,
+            norm_f,
+        })
+    }
+
+    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (_b_size, seq_len) = xs.dims2()?;
+        let mut xs = xs.apply(&self.wte)?;
+        let mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(super::mpt::get_mask(seq_len, xs.device())?)
+        };
+        for block in self.blocks.iter_mut() {
+            xs = block.forward(&xs, mask.as_ref())?;
+        }
+        let xs = xs.apply(&self.norm_f)?;
+        let logits = xs
+            .narrow(1, seq_len - 1, 1)?
+            .squeeze(1)?
+            .matmul(&self.wte.embeddings().t()?)?
+            .squeeze(1)?;
+        Ok(logits)
+    }
+}
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index 1745327d..d71c3b60 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -59,6 +59,11 @@ pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
     Ok(candle_nn::LayerNorm::new(weight, bias, eps))
 }
 
+pub fn layer_norm_no_bias(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
+    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+    Ok(candle_nn::LayerNorm::new_no_bias(weight, eps))
+}
+
 pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
     let weight = QMatMul::new(in_dim, out_dim, vb)?;
     Ok(Linear { weight, bias: None })
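
For reference, a minimal sketch (not part of the patch) of loading the quantized model on its own, mirroring the loading logic added to main.rs above; the `load_quantized` helper and the GGUF path argument are hypothetical, assuming the model-replit-code-v1_5-q4k.gguf file has already been downloaded locally:

    use candle_transformers::models::mpt::Config;
    use candle_transformers::models::quantized_mpt::Model;
    use candle_transformers::quantized_var_builder::VarBuilder;

    // `gguf_path` points at a locally available model-replit-code-v1_5-q4k.gguf file.
    fn load_quantized(gguf_path: &std::path::Path) -> candle::Result<Model> {
        // Same configuration as the example: replit-code-v1_5-3b.
        let config = Config::replit_code_v1_5_3b();
        // GGUF weights go through the quantized VarBuilder; as in the example,
        // this path runs on the CPU.
        let vb = VarBuilder::from_gguf(gguf_path)?;
        Model::new(&config, vb.pp("transformer"))
    }

Generation then calls the returned model's forward method on token-id tensors, exactly as the Model::Q arm of the new enum in main.rs does.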