From 0007ae9c119d567e3143000e9321cdbfe32b4da2 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 24 Sep 2023 15:03:48 +0100
Subject: [PATCH] Add the quantized mixformer model. (#953)

* Add the quantized mixformer model.

* Add the quantized option in the phi example.
---
 candle-examples/examples/phi/main.rs          |  32 +-
 candle-transformers/src/models/mixformer.rs   |  22 +-
 candle-transformers/src/models/mod.rs         |   1 +
 .../src/models/quantized_mixformer.rs         | 344 ++++++++++++++++++
 .../src/models/quantized_t5.rs                |  35 +-
 .../src/models/with_tracing.rs                |  32 ++
 6 files changed, 418 insertions(+), 48 deletions(-)
 create mode 100644 candle-transformers/src/models/quantized_mixformer.rs

diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 3b1e7dc1..9bd7e8da 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as Model};
+use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
+use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 
@@ -15,6 +16,11 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+enum Model {
+    MixFormer(MixFormer),
+    Quantized(QMixFormer),
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -58,7 +64,10 @@ impl TextGeneration {
             let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let logits = self.model.forward(&input)?;
+            let logits = match &mut self.model {
+                Model::MixFormer(m) => m.forward(&input)?,
+                Model::Quantized(m) => m.forward(&input)?,
+            };
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
 
             let next_token = self.logits_processor.sample(&logits)?;
@@ -115,6 +124,9 @@ struct Args {
 
     #[arg(long)]
     weight_file: Option<String>,
+
+    #[arg(long)]
+    quantized: bool,
 }
 
 fn main() -> Result<()> {
@@ -150,10 +162,18 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-    let config = Config::v1_5();
-    let model = Model::new(&config, vb)?;
+    let (model, device) = if args.quantized {
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+        let config = Config::v1_5();
+        let model = QMixFormer::new(&config, vb)?;
+        (Model::Quantized(model), Device::Cpu)
+    } else {
+        let device = candle_examples::device(args.cpu)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        let config = Config::v1_5();
+        let model = MixFormer::new(&config, vb)?;
+        (Model::MixFormer(model), device)
+    };
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 6a3b5515..e945cd51 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -10,17 +10,17 @@ const MAX_SEQ_LEN: usize = 4096;
 // https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
 #[derive(Debug, Clone, PartialEq)]
 pub struct Config {
-    vocab_size: usize,
-    n_positions: usize,
-    n_embd: usize,
-    n_layer: usize,
-    n_inner: Option<usize>,
-    n_head: usize,
-    rotary_dim: usize,
-    activation_function: Activation,
-    layer_norm_epsilon: f64,
-    tie_word_embeddings: bool,
-    pad_vocab_size_multiple: usize,
+    pub(crate) vocab_size: usize,
+    pub(crate) n_positions: usize,
+    pub(crate) n_embd: usize,
+    pub(crate) n_layer: usize,
+    pub(crate) n_inner: Option<usize>,
+    pub(crate) n_head: usize,
+    pub(crate) rotary_dim: usize,
+    pub(crate) activation_function: Activation,
+    pub(crate) layer_norm_epsilon: f64,
+    pub(crate) tie_word_embeddings: bool,
+    pub(crate) pad_vocab_size_multiple: usize,
 }
 
 impl Config {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 0fbcaa07..d6d5edf3 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -6,6 +6,7 @@ pub mod falcon;
 pub mod llama;
 pub mod mixformer;
 pub mod quantized_llama;
+pub mod quantized_mixformer;
 pub mod quantized_t5;
 pub mod segment_anything;
 pub mod stable_diffusion;
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
new file mode 100644
index 00000000..4ace2045
--- /dev/null
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -0,0 +1,344 @@
+use crate::models::with_tracing::QMatMul;
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::Activation;
+
+pub use crate::models::mixformer::Config;
+
+const MAX_SEQ_LEN: usize = 4096;
+
+#[derive(Debug)]
+struct Embedding {
+    wte: super::quantized_t5::Embedding,
+}
+
+impl Embedding {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let wte = super::quantized_t5::Embedding::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
+        Ok(Self { wte })
+    }
+}
+
+impl Module for Embedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.wte.forward(xs)
+    }
+}
+
+#[derive(Debug)]
+struct Linear {
+    weight: QMatMul,
+    bias: Option<Tensor>,
+}
+
+impl Module for Linear {
+    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+        let x = x.apply(&self.weight)?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => x.broadcast_add(bias),
+        }
+    }
+}
+
+fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+    let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
+    let weight = QMatMul::new(in_dim, out_dim, vb)?;
+    Ok(Linear {
+        weight,
+        bias: Some(bias),
+    })
+}
+
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
+    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
+    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
+}
+
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
+#[derive(Debug)]
+struct RotaryEmbedding {
+    sin: Tensor,
+    cos: Tensor,
+}
+
+impl RotaryEmbedding {
+    fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
+        let inv_freq: Vec<_> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+            .to_dtype(DType::F32)?
+            .reshape((max_seq_len, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        Ok(Self {
+            sin: freqs.sin()?,
+            cos: freqs.cos()?,
+        })
+    }
+
+    fn apply_rotary_emb_qkv(
+        &self,
+        qkv: &Tensor,
+        seqlen_offset: usize,
+    ) -> Result<(Tensor, Tensor, Tensor)> {
+        let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
+        if three != 3 {
+            candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
+        }
+        let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
+        let rotary_dim = rotary_dim * 2;
+        let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
+        let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
+        let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
+        let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
+        let q12 = q_rot.chunk(2, D::Minus1)?;
+        let k12 = k_rot.chunk(2, D::Minus1)?;
+        let (q1, q2) = (&q12[0], &q12[1]);
+        let (k1, k2) = (&k12[0], &k12[1]);
+        let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
+        let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
+        let q_rot = Tensor::cat(
+            &[
+                (q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
+                (q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
+            ],
+            D::Minus1,
+        )?;
+        let k_rot = Tensor::cat(
+            &[
+                (k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
+                (k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
+            ],
+            D::Minus1,
+        )?;
+        let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
+        let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
+        let v = qkv.i((.., .., 2))?;
+        Ok((q, k, v))
+    }
+}
+
+#[derive(Debug)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+    fc1: Linear,
+    fc2: Linear,
+    act: Activation,
+}
+
+impl MLP {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
+        let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
+        let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
+        Ok(Self {
+            fc1,
+            fc2,
+            act: cfg.activation_function,
+        })
+    }
+}
+
+impl Module for MLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
+    }
+}
+
+#[derive(Debug)]
+struct CausalLMHead {
+    ln: candle_nn::LayerNorm,
+    linear: Linear,
+}
+
+impl CausalLMHead {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln = layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
+        let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
+        Ok(Self { ln, linear })
+    }
+}
+
+impl Module for CausalLMHead {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.ln)?
+            .apply(&self.linear)?
+            .to_dtype(DType::F32)
+    }
+}
+
+#[derive(Debug)]
+#[allow(clippy::upper_case_acronyms)]
+struct MHA {
+    wqkv: Linear,
+    out_proj: Linear,
+    rotary_emb: RotaryEmbedding,
+    kv_cache: Option<(Tensor, Tensor)>,
+    head_dim: usize,
+    n_head: usize,
+    softmax_scale: f64,
+    span: tracing::Span,
+}
+
+impl MHA {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let head_dim = cfg.n_embd / cfg.n_head;
+        let op_size = cfg.n_embd;
+        let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
+        let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
+        let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
+        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
+        Ok(Self {
+            wqkv,
+            out_proj,
+            head_dim,
+            n_head: cfg.n_head,
+            kv_cache: None,
+            rotary_emb,
+            softmax_scale,
+            span: tracing::span!(tracing::Level::TRACE, "mha"),
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (b_size, seq_len, _n_embd) = xs.dims3()?;
+        let qkv = self
+            .wqkv
+            .forward(xs)?
+            .reshape((b_size, seq_len, 3, (), self.head_dim))?;
+        let seqlen_offset = match &self.kv_cache {
+            None => 0,
+            Some((prev_k, _)) => prev_k.dim(1)?,
+        };
+        // In the python implementation, a single tensor is returned with the third axis of size 3.
+        let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
+        let (k, v) = match &self.kv_cache {
+            None => (k, v),
+            Some((prev_k, prev_v)) => {
+                let k = Tensor::cat(&[prev_k, &k], 1)?;
+                let v = Tensor::cat(&[prev_v, &v], 1)?;
+                (k, v)
+            }
+        };
+        self.kv_cache = Some((k.clone(), v.clone()));
+        // scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        let q = q.transpose(1, 2)?.flatten_to(1)?; // b*h, t, d
+        let k = k.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
+        let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
+        let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
+
+        // causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
+        // scores = scores + causal_mask.to(dtype=scores.dtype)
+        let attn_weights = match mask {
+            None => attn_weights,
+            Some(mask) => masked_fill(
+                &attn_weights,
+                &mask.broadcast_left(b_size * self.n_head)?,
+                f32::NEG_INFINITY,
+            )?,
+        };
+        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+
+        // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+        // attn_weights: b*h,t,s, v: b*h,s,d
+        let attn_output = attn_weights.matmul(&v)?;
+        // b*h,t,d
+        let attn_output = attn_output
+            .reshape((b_size, (), seq_len, self.head_dim))?
+            .transpose(1, 2)?
+            .flatten_from(D::Minus2)?;
+        attn_output.apply(&self.out_proj)
+    }
+}
+
+#[derive(Debug)]
+struct ParallelBlock {
+    ln: candle_nn::LayerNorm,
+    mixer: MHA,
+    mlp: MLP,
+    span: tracing::Span,
+}
+
+impl ParallelBlock {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln = layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
+        let mixer = MHA::new(cfg, vb.pp("mixer"))?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        Ok(Self {
+            ln,
+            mixer,
+            mlp,
+            span: tracing::span!(tracing::Level::TRACE, "block"),
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let residual = xs;
+        let xs = xs.apply(&self.ln)?;
+        let attn_outputs = self.mixer.forward(&xs, mask)?;
+        let feed_forward_hidden_states = self.mlp.forward(&xs)?;
+        attn_outputs + feed_forward_hidden_states + residual
+    }
+}
+
+#[derive(Debug)]
+pub struct MixFormerSequentialForCausalLM {
+    embedding: Embedding,
+    blocks: Vec<ParallelBlock>,
+    head: CausalLMHead,
+    span: tracing::Span,
+}
+
+impl MixFormerSequentialForCausalLM {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb = vb.pp("layers");
+        let embedding = Embedding::new(cfg, vb.pp(0))?;
+        let mut blocks = Vec::new();
+        for i in 0..cfg.n_layer {
+            let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
+            blocks.push(block)
+        }
+        let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
+        Ok(Self {
+            embedding,
+            blocks,
+            head,
+            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
+        })
+    }
+
+    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (_b_size, seq_len) = xs.dims2()?;
+        let mut xs = xs.apply(&self.embedding)?;
+        let mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(get_mask(seq_len, xs.device())?)
+        };
+        for block in self.blocks.iter_mut() {
+            xs = block.forward(&xs, mask.as_ref())?
+        }
+        xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
+    }
+}
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index d2fa0e2d..7f7d53dd 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -1,6 +1,7 @@
 // T5 Text Model, quantized version
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
+use crate::models::with_tracing::QMatMul;
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::Activation;
@@ -8,20 +9,20 @@ use serde::Deserialize;
 use std::sync::Arc;
 
 #[derive(Debug)]
-struct Embedding {
+pub struct Embedding {
     inner: candle_nn::Embedding,
     span: tracing::Span,
 }
 
 impl Embedding {
-    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+    pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
         let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
         let inner = candle_nn::Embedding::new(embeddings, d2);
         let span = tracing::span!(tracing::Level::TRACE, "embedding");
         Ok(Self { inner, span })
     }
 
-    fn embeddings(&self) -> &Tensor {
+    pub fn embeddings(&self) -> &Tensor {
         self.inner.embeddings()
     }
 }
@@ -33,34 +34,6 @@ impl Module for Embedding {
     }
 }
 
-// QMatMul wrapper adding some tracing.
-struct QMatMul {
-    inner: candle::quantized::QMatMul,
-    span: tracing::Span,
-}
-
-impl QMatMul {
-    fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> {
-        let ws = vb.get((in_dim, out_dim), "weight")?;
-        let inner = candle::quantized::QMatMul::from_arc(ws);
-        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for QMatMul {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(xs)
-    }
-}
-
-impl std::fmt::Debug for QMatMul {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "QMatMul")
-    }
-}
-
 fn default_relative_attention_max_distance() -> usize {
     128
 }
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 0a2d65b9..6a6c69e7 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -76,3 +76,35 @@ pub fn conv2d(
     let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
     Ok(Conv2d { inner, span })
 }
+
+// QMatMul wrapper adding some tracing.
+pub struct QMatMul {
+    inner: candle::quantized::QMatMul,
+    span: tracing::Span,
+}
+
+impl QMatMul {
+    pub fn new(
+        out_dim: usize,
+        in_dim: usize,
+        vb: crate::quantized_var_builder::VarBuilder,
+    ) -> Result<Self> {
+        let ws = vb.get((in_dim, out_dim), "weight")?;
+        let inner = candle::quantized::QMatMul::from_arc(ws);
+        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
+        Ok(Self { inner, span })
+    }
+}
+
+impl Module for QMatMul {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
+impl std::fmt::Debug for QMatMul {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "QMatMul")
+    }
+}
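
Usage sketch (not part of the patch): the `--quantized` flag added above exercises this path through the phi example; standalone use of the quantized model would look roughly as follows. The helper name, GGUF path and prompt token ids are placeholders, and the crate paths follow the patch, where the core crate is referenced as `candle`.

    use candle::{Device, Result, Tensor};
    use candle_transformers::models::mixformer::Config;
    use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM;
    use candle_transformers::quantized_var_builder::VarBuilder;

    // Hypothetical helper: load quantized phi-1.5 weights from a GGUF file and
    // return the logits for the last position of the prompt.
    fn last_token_logits(gguf_path: &std::path::Path, prompt_tokens: &[u32]) -> Result<Tensor> {
        // The quantized model runs on the CPU, mirroring the `--quantized`
        // branch added to the phi example above.
        let vb = VarBuilder::from_gguf(gguf_path)?;
        let config = Config::v1_5();
        let mut model = MixFormerSequentialForCausalLM::new(&config, vb)?;
        // `forward` takes a (batch, seq_len) tensor of token ids and returns
        // the (batch, vocab_size) logits for the last position.
        let input = Tensor::new(prompt_tokens, &Device::Cpu)?.unsqueeze(0)?;
        model.forward(&input)
    }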