Quantized version of StableLM. (#1058)

* Quantized version of StableLM.

* Adapt the stable-lm example to support quantizsed.

* Use some separate hub repo.

* Another repo name tweak.
This commit is contained in:
Laurent Mazare
2023-10-08 15:42:38 +01:00
committed by GitHub
parent 783735cf22
commit 59ab6d7832
4 changed files with 331 additions and 14 deletions

View File

@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::stable_lm::{Config, Model};
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -16,6 +17,11 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
StableLM(StableLM),
Quantized(QStableLM),
}
struct TextGeneration {
model: Model,
device: Device,
@ -76,7 +82,10 @@ impl TextGeneration {
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = match &mut self.model {
Model::StableLM(m) => m.forward(&input, start_pos)?,
Model::Quantized(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
@ -146,7 +155,7 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "stabilityai/stablelm-3b-4e1t")]
#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
model_id: String,
#[arg(long, default_value = "main")]
@ -213,15 +222,24 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
vec![repo.get("model-q4k.gguf")?]
} else {
vec![repo.get("model.safetensors")?]
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
let (model, device) = {
let (model, device) = if args.quantized {
let filename = &filenames[0];
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
@ -229,8 +247,8 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
(model, device)
let model = StableLM::new(&config, vb)?;
(Model::StableLM(model), device)
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -9,6 +9,7 @@ pub mod mixformer;
pub mod quantized_llama;
pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod segment_anything;
pub mod stable_diffusion;

View File

@ -0,0 +1,299 @@
use crate::models::quantized_t5::Embedding;
use crate::models::with_tracing::QMatMul;
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm};
use std::sync::Arc;
pub use crate::models::stable_lm::Config;
use crate::models::stable_lm::RotaryEmbedding;
#[derive(Debug)]
struct Linear {
weight: QMatMul,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
x.apply(&self.weight)
}
}
fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
let weight = QMatMul::new(in_dim, out_dim, vb)?;
Ok(Linear { weight })
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
Ok(candle_nn::LayerNorm::new(weight, bias, eps))
}
#[derive(Debug)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
use_cache: bool,
rotary_ndims: usize,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let head_dim = cfg.head_dim();
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups: cfg.num_kv_groups(),
head_dim,
hidden_size: hidden_sz,
rotary_emb,
kv_cache: None,
use_cache: cfg.use_cache,
rotary_ndims: cfg.rotary_ndims(),
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims);
let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?;
let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?;
let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
let (query_rot, key_rot) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?;
let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?;
let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
if self.use_cache {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
}
#[derive(Debug)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: LayerNorm,
post_attention_layernorm: LayerNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = layer_norm(
cfg.hidden_size,
cfg.norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
}
#[derive(Debug)]
pub struct Model {
embed_tokens: Embedding,
layers: Vec<DecoderLayer>,
norm: LayerNorm,
lm_head: Linear,
device: Device,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(DType::F32, cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
device: vb.device().clone(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
// Sliding window mask?
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(DType::F32)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
}

View File

@ -1,4 +1,3 @@
#![allow(unused)]
use crate::models::with_tracing::{linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm, VarBuilder};
@ -41,21 +40,21 @@ impl Config {
}
}
fn head_dim(&self) -> usize {
pub fn head_dim(&self) -> usize {
self.hidden_size / self.num_attention_heads
}
fn rotary_ndims(&self) -> usize {
pub fn rotary_ndims(&self) -> usize {
(self.head_dim() as f64 * self.rope_pct) as usize
}
fn num_kv_groups(&self) -> usize {
pub fn num_kv_groups(&self) -> usize {
self.num_attention_heads / self.num_key_value_heads
}
}
#[derive(Debug)]
struct RotaryEmbedding {
pub(crate) struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
@ -66,7 +65,7 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.rotary_ndims();
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
@ -86,7 +85,7 @@ impl RotaryEmbedding {
})
}
fn apply_rotary_emb_qkv(
pub(crate) fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,