mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the quantized mpt model. (#1123)
* Add the quantized mpt model. * Support the quantized model for replit-code.
This commit is contained in:
@ -11,6 +11,7 @@ pub mod mpt;
|
||||
pub mod quantized_llama;
|
||||
pub mod quantized_mistral;
|
||||
pub mod quantized_mixformer;
|
||||
pub mod quantized_mpt;
|
||||
pub mod quantized_stable_lm;
|
||||
pub mod quantized_t5;
|
||||
pub mod segment_anything;
|
||||
|
@ -137,7 +137,7 @@ impl GroupedQueryAttention {
|
||||
// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
||||
// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
|
||||
// (batch, num_attention_heads, seqlen, head_dim)
|
||||
fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
|
||||
pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
|
||||
if n_rep == 1 {
|
||||
Ok(xs)
|
||||
} else {
|
||||
@ -206,7 +206,7 @@ impl MPTBlock {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
|
||||
pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
|
||||
let full = !cfg.is_causal();
|
||||
let seq_len = cfg.max_seq_len;
|
||||
let alibi_bias = Tensor::arange(1 - seq_len as i64, 1, &Device::Cpu)?;
|
||||
@ -289,14 +289,14 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
pub(crate) fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device)
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
pub(crate) fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
|
201
candle-transformers/src/models/quantized_mpt.rs
Normal file
201
candle-transformers/src/models/quantized_mpt.rs
Normal file
@ -0,0 +1,201 @@
|
||||
use crate::quantized_nn::{layer_norm_no_bias, linear_no_bias, Embedding, Linear};
|
||||
pub use crate::quantized_var_builder::VarBuilder;
|
||||
/// MPT model used by replit-code-v1_5-3b
|
||||
/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
|
||||
use candle::{IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::LayerNorm;
|
||||
|
||||
pub use super::mpt::Config;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct GroupedQueryAttention {
|
||||
wqkv: Linear,
|
||||
out_proj: Linear,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
softmax_scale: f64,
|
||||
head_dim: usize,
|
||||
d_model: usize,
|
||||
n_heads: usize,
|
||||
kv_n_heads: usize,
|
||||
attn_bias: Tensor,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl GroupedQueryAttention {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let head_dim = cfg.d_model / cfg.n_heads;
|
||||
let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
|
||||
let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
|
||||
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||
let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
|
||||
let attn_bias = super::mpt::build_alibi_bias(cfg)?.to_device(vb.device())?;
|
||||
Ok(Self {
|
||||
wqkv,
|
||||
out_proj,
|
||||
kv_cache: None,
|
||||
softmax_scale,
|
||||
head_dim,
|
||||
d_model: cfg.d_model,
|
||||
n_heads: cfg.n_heads,
|
||||
kv_n_heads: cfg.kv_n_heads,
|
||||
attn_bias,
|
||||
span: tracing::span!(tracing::Level::TRACE, "gqa"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||
let qkv = self.wqkv.forward(xs)?;
|
||||
let query = qkv.narrow(2, 0, self.d_model)?;
|
||||
let kv_size = self.kv_n_heads * self.head_dim;
|
||||
let key = qkv.narrow(2, self.d_model, kv_size)?;
|
||||
let value = qkv.narrow(2, self.d_model + kv_size, kv_size)?;
|
||||
// scaled_multihead_dot_product_attention
|
||||
let query = query
|
||||
.reshape((b_size, seq_len, self.n_heads, ()))?
|
||||
.transpose(1, 2)?; // b,h,s,d
|
||||
let key = key
|
||||
.reshape((b_size, seq_len, self.kv_n_heads, ()))?
|
||||
.permute((0, 2, 3, 1))?; // b,h,d,s
|
||||
let value = value
|
||||
.reshape((b_size, seq_len, self.kv_n_heads, ()))?
|
||||
.transpose(1, 2)?; // b,h,s,d
|
||||
let (key, value) = match &self.kv_cache {
|
||||
None => (key, value),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let k = Tensor::cat(&[prev_k, &key], 3)?;
|
||||
let v = Tensor::cat(&[prev_v, &value], 2)?;
|
||||
(k, v)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key.clone(), value.clone()));
|
||||
let query = query.contiguous()?;
|
||||
let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
|
||||
let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
|
||||
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
|
||||
let attn_bias = {
|
||||
let s_q = query.dim(D::Minus2)?;
|
||||
let s_k = key.dim(D::Minus1)?;
|
||||
let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
|
||||
let start_q = a_q.saturating_sub(s_q);
|
||||
let start_k = a_k.saturating_sub(s_k);
|
||||
self.attn_bias.i((.., .., start_q.., start_k..))?
|
||||
};
|
||||
let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
|
||||
let attn_weights = match mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => super::mpt::masked_fill(
|
||||
&attn_weights,
|
||||
&mask.broadcast_as(attn_weights.shape())?,
|
||||
f32::NEG_INFINITY,
|
||||
)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_weights
|
||||
.matmul(&value)?
|
||||
.transpose(1, 2)?
|
||||
.flatten_from(D::Minus2)?;
|
||||
let out = attn_output.apply(&self.out_proj)?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Ffn {
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
}
|
||||
|
||||
impl Ffn {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden = cfg.d_model * cfg.expansion_ratio;
|
||||
let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
|
||||
let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
|
||||
Ok(Self { up_proj, down_proj })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Ffn {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.up_proj)?.gelu_erf()?.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MPTBlock {
|
||||
norm1: LayerNorm, // Do we need the low-precision variant?
|
||||
attn: GroupedQueryAttention,
|
||||
norm2: LayerNorm,
|
||||
ffn: Ffn,
|
||||
}
|
||||
|
||||
impl MPTBlock {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let norm1 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
|
||||
let norm2 = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
|
||||
let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
|
||||
let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
norm2,
|
||||
ffn,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.norm1)?;
|
||||
let xs = self.attn.forward(&xs, mask)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Model {
|
||||
wte: Embedding,
|
||||
blocks: Vec<MPTBlock>,
|
||||
norm_f: LayerNorm,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let wte = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
|
||||
let vb_b = vb.pp("blocks");
|
||||
let mut blocks = Vec::with_capacity(cfg.n_layers);
|
||||
for i in 0..cfg.n_layers {
|
||||
let block = MPTBlock::new(cfg, vb_b.pp(i))?;
|
||||
blocks.push(block)
|
||||
}
|
||||
let norm_f = layer_norm_no_bias(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
|
||||
Ok(Self {
|
||||
wte,
|
||||
blocks,
|
||||
norm_f,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b_size, seq_len) = xs.dims2()?;
|
||||
let mut xs = xs.apply(&self.wte)?;
|
||||
let mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(super::mpt::get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs, mask.as_ref())?;
|
||||
}
|
||||
let xs = xs.apply(&self.norm_f)?;
|
||||
let logits = xs
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.squeeze(1)?
|
||||
.matmul(&self.wte.embeddings().t()?)?
|
||||
.squeeze(1)?;
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
@ -59,6 +59,11 @@ pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::La
|
||||
Ok(candle_nn::LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
pub fn layer_norm_no_bias(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
||||
Ok(candle_nn::LayerNorm::new_no_bias(weight, eps))
|
||||
}
|
||||
|
||||
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
||||
Ok(Linear { weight, bias: None })
|
||||
|
Reference in New Issue
Block a user