From 59ab6d7832600083a1519aa0511e9c7c832ae01c Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 8 Oct 2023 15:42:38 +0100
Subject: [PATCH] Quantized version of StableLM. (#1058)

* Quantized version of StableLM.

* Adapt the stable-lm example to support quantized.

* Use some separate hub repo.

* Another repo name tweak.
---
 candle-examples/examples/stable-lm/main.rs    |  32 +-
 candle-transformers/src/models/mod.rs         |   1 +
 .../src/models/quantized_stable_lm.rs         | 299 ++++++++++++++++++
 candle-transformers/src/models/stable_lm.rs   |  13 +-
 4 files changed, 331 insertions(+), 14 deletions(-)
 create mode 100644 candle-transformers/src/models/quantized_stable_lm.rs

diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 95521265..0535aa70 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle_transformers::models::stable_lm::{Config, Model};
+use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
+use candle_transformers::models::stable_lm::{Config, Model as StableLM};
 
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +17,11 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+enum Model {
+    StableLM(StableLM),
+    Quantized(QStableLM),
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -76,7 +82,10 @@ impl TextGeneration {
             let start_pos = tokens.len().saturating_sub(context_size);
             let ctxt = &tokens[start_pos..];
             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let logits = self.model.forward(&input, start_pos)?;
+            let logits = match &mut self.model {
+                Model::StableLM(m) => m.forward(&input, start_pos)?,
+                Model::Quantized(m) => m.forward(&input, start_pos)?,
+            };
             let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
                 logits
@@ -146,7 +155,7 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
-    #[arg(long, default_value = "stabilityai/stablelm-3b-4e1t")]
+    #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
     model_id: String,
 
     #[arg(long, default_value = "main")]
@@ -213,7 +222,11 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => {
-            vec![repo.get("model.safetensors")?]
+            if args.quantized {
+                vec![repo.get("model-q4k.gguf")?]
+            } else {
+                vec![repo.get("model.safetensors")?]
+            }
         }
     };
     println!("retrieved the files in {:?}", start.elapsed());
@@ -221,7 +234,12 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
-    let (model, device) = {
+    let (model, device) = if args.quantized {
+        let filename = &filenames[0];
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let model = QStableLM::new(&config, vb)?;
+        (Model::Quantized(model), Device::Cpu)
+    } else {
         let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
@@ -229,8 +247,8 @@ fn main() -> Result<()> {
             DType::F32
         };
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        let model = Model::new(&config, vb)?;
-        (model, device)
+        let model = StableLM::new(&config, vb)?;
+        (Model::StableLM(model), device)
     };
 
     println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 7638dda3..81044112 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -9,6 +9,7 @@ pub mod mixformer;
 pub mod quantized_llama;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
+pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod segment_anything;
 pub mod stable_diffusion;
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
new file mode 100644
index 00000000..86964237
--- /dev/null
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -0,0 +1,299 @@
+use crate::models::quantized_t5::Embedding;
+use crate::models::with_tracing::QMatMul;
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{Activation, LayerNorm};
+use std::sync::Arc;
+
+pub use crate::models::stable_lm::Config;
+use crate::models::stable_lm::RotaryEmbedding;
+
+#[derive(Debug)]
+struct Linear {
+    weight: QMatMul,
+}
+
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+        x.apply(&self.weight)
+    }
+}
+
+fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+    let weight = QMatMul::new(in_dim, out_dim, vb)?;
+    Ok(Linear { weight })
+}
+
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
+    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
+}
+
+#[derive(Debug)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+    gate_proj: Linear,
+    up_proj: Linear,
+    down_proj: Linear,
+    act_fn: Activation,
+}
+
+impl MLP {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_sz = cfg.hidden_size;
+        let intermediate_sz = cfg.intermediate_size;
+        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
+        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
+        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
+        Ok(Self {
+            gate_proj,
+            up_proj,
+            down_proj,
+            act_fn: cfg.hidden_act,
+        })
+    }
+}
+
+impl Module for MLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
+        let rhs = xs.apply(&self.up_proj)?;
+        (lhs * rhs)?.apply(&self.down_proj)
+    }
+}
+
+#[derive(Debug)]
+struct Attention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    o_proj: Linear,
+    num_heads: usize,
+    num_kv_heads: usize,
+    num_kv_groups: usize,
+    head_dim: usize,
+    hidden_size: usize,
+    rotary_emb: Arc<RotaryEmbedding>,
+    kv_cache: Option<(Tensor, Tensor)>,
+    use_cache: bool,
+    rotary_ndims: usize,
+}
+
+impl Attention {
+    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_sz = cfg.hidden_size;
+        let head_dim = cfg.head_dim();
+        let num_heads = cfg.num_attention_heads;
+        let num_kv_heads = cfg.num_key_value_heads;
+        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            o_proj,
+            num_heads,
+            num_kv_heads,
+            num_kv_groups: cfg.num_kv_groups(),
+            head_dim,
+            hidden_size: hidden_sz,
+            rotary_emb,
+            kv_cache: None,
+            use_cache: cfg.use_cache,
+            rotary_ndims: cfg.rotary_ndims(),
+        })
+    }
+
+    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
+        let n_rep = self.num_kv_groups;
+        if n_rep == 1 {
+            Ok(xs)
+        } else {
+            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
+            xs.unsqueeze(2)?
+                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
+                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
+        }
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        attention_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (b_sz, q_len, _) = xs.dims3()?;
+
+        let query_states = self.q_proj.forward(xs)?;
+        let key_states = self.k_proj.forward(xs)?;
+        let value_states = self.v_proj.forward(xs)?;
+
+        let query_states = query_states
+            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+            .transpose(1, 2)?;
+        let key_states = key_states
+            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?;
+        let value_states = value_states
+            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?;
+
+        let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims);
+        let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?;
+        let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
+        let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?;
+        let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
+        let (query_rot, key_rot) =
+            self.rotary_emb
+                .apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?;
+        let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?;
+        let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?;
+
+        let (key_states, value_states) = match &self.kv_cache {
+            None => (key_states, value_states),
+            Some((prev_k, prev_v)) => {
+                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
+                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
+                (key_states, value_states)
+            }
+        };
+        if self.use_cache {
+            self.kv_cache = Some((key_states.clone(), value_states.clone()));
+        }
+
+        let key_states = self.repeat_kv(key_states)?.contiguous()?;
+        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+
+        let attn_output = {
+            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+
+            let attn_weights = match attention_mask {
+                None => attn_weights,
+                Some(mask) => attn_weights.broadcast_add(mask)?,
+            };
+            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+            attn_weights.matmul(&value_states)?
+        };
+        attn_output
+            .transpose(1, 2)?
+            .reshape((b_sz, q_len, self.hidden_size))?
+            .apply(&self.o_proj)
+    }
+}
+
+#[derive(Debug)]
+struct DecoderLayer {
+    self_attn: Attention,
+    mlp: MLP,
+    input_layernorm: LayerNorm,
+    post_attention_layernorm: LayerNorm,
+}
+
+impl DecoderLayer {
+    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        let input_layernorm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let post_attention_layernorm = layer_norm(
+            cfg.hidden_size,
+            cfg.norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
+        Ok(Self {
+            self_attn,
+            mlp,
+            input_layernorm,
+            post_attention_layernorm,
+        })
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        attention_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let residual = xs;
+        let xs = self.input_layernorm.forward(xs)?;
+        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
+        let xs = (xs + residual)?;
+        let residual = &xs;
+        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
+        residual + xs
+    }
+}
+
+#[derive(Debug)]
+pub struct Model {
+    embed_tokens: Embedding,
+    layers: Vec<DecoderLayer>,
+    norm: LayerNorm,
+    lm_head: Linear,
+    device: Device,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("model");
+        let embed_tokens =
+            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+        let rotary_emb = Arc::new(RotaryEmbedding::new(DType::F32, cfg, vb_m.device())?);
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_l = vb_m.pp("layers");
+        for layer_idx in 0..cfg.num_hidden_layers {
+            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+            layers.push(layer)
+        }
+        let norm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
+        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        Ok(Self {
+            embed_tokens,
+            layers,
+            norm,
+            lm_head,
+            device: vb.device().clone(),
+        })
+    }
+
+    fn prepare_decoder_attention_mask(
+        &self,
+        b_size: usize,
+        tgt_len: usize,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        // Sliding window mask?
+        let mask: Vec<_> = (0..tgt_len)
+            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+            .collect();
+        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
+        let mask = if seqlen_offset > 0 {
+            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
+            Tensor::cat(&[&mask0, &mask], D::Minus1)?
+        } else {
+            mask
+        };
+        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+            .to_dtype(DType::F32)
+    }
+
+    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let (b_size, seq_len) = input_ids.dims2()?;
+        let attention_mask = if seq_len <= 1 {
+            None
+        } else {
+            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+            Some(mask)
+        };
+        let mut xs = self.embed_tokens.forward(input_ids)?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+        }
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
+    }
+}
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index 87e72396..affb28cf 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use crate::models::with_tracing::{linear_no_bias, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm, VarBuilder};
@@ -41,21 +40,21 @@ impl Config {
         }
     }
 
-    fn head_dim(&self) -> usize {
+    pub fn head_dim(&self) -> usize {
         self.hidden_size / self.num_attention_heads
     }
 
-    fn rotary_ndims(&self) -> usize {
+    pub fn rotary_ndims(&self) -> usize {
         (self.head_dim() as f64 * self.rope_pct) as usize
     }
 
-    fn num_kv_groups(&self) -> usize {
+    pub fn num_kv_groups(&self) -> usize {
         self.num_attention_heads / self.num_key_value_heads
     }
 }
 
 #[derive(Debug)]
-struct RotaryEmbedding {
+pub(crate) struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
 }
@@ -66,7 +65,7 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 }
 
 impl RotaryEmbedding {
-    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+    pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.rotary_ndims();
         let max_seq_len = cfg.max_position_embeddings;
         let inv_freq: Vec<_> = (0..dim)
@@ -86,7 +85,7 @@ impl RotaryEmbedding {
         })
     }
 
-    fn apply_rotary_emb_qkv(
+    pub(crate) fn apply_rotary_emb_qkv(
         &self,
         q: &Tensor,
         k: &Tensor,