MPT fixes. (#1117)

* MPT fixes.

* Another couple fixes.

* Another shape fix.
Author: Laurent Mazare
Date: 2023-10-17 21:53:31 +01:00
Committed by: GitHub
Parent: a72b50e2c0
Commit: 2cd745a97c
2 changed files with 22 additions and 13 deletions

File: candle-examples/examples/replit-code/main.rs

@@ -215,7 +215,7 @@ fn main() -> Result<()> {
     let config = Config::replit_code_v1_5_3b();
     let device = candle_examples::device(args.cpu)?;
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(&config, vb.pp("transformer"))?;
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(
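The example was building the model from the root of the checkpoint, but the replit-code safetensors file stores its tensors under a transformer. prefix, so the weights could not be found; VarBuilder::pp pushes that prefix onto every subsequent lookup. A minimal sketch of how the prefixes compose, assuming the MPT key layout (the helper itself is illustrative):

use candle::{Result, Tensor};
use candle_nn::VarBuilder;

// Illustrative helper: with the checkpoint storing keys such as
// "transformer.wte.weight", lookups must go through the "transformer"
// sub-builder so that relative names resolve to the right tensors.
fn load_wte(vb: &VarBuilder, vocab_size: usize, d_model: usize) -> Result<Tensor> {
    // Resolves "weight" against the key "transformer.wte.weight".
    vb.pp("transformer").pp("wte").get((vocab_size, d_model), "weight")
}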

File: candle-transformers/src/models/mpt.rs

@@ -1,5 +1,5 @@
 #![allow(unused)]
-use crate::models::with_tracing::{linear, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
 /// MPT model used by replit-code-v1_5-3b
 /// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
@@ -57,11 +57,11 @@ struct GroupedQueryAttention {
 impl GroupedQueryAttention {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads;
-        let wqkv = linear(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let head_dim = cfg.d_model / cfg.n_heads;
+        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
+        let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
-        let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+        let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
         let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
         Ok(Self {
             wqkv,
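Two things change here. The old wqkv_size dropped the head_dim factor: in grouped-query attention the fused projection emits n_heads * head_dim query features (equal to d_model) plus kv_n_heads * head_dim features each for keys and values. The switch to linear_no_bias (also in the import above) matches the checkpoint, which carries no bias tensors on its linear layers. A sizing sketch, using replit-code-v1_5-3b-like numbers that are assumptions here (d_model = 3072, n_heads = 24, kv_n_heads = 8):

// Illustrative sizing check for the fused Wqkv projection under
// grouped-query attention.
fn wqkv_out_features(d_model: usize, n_heads: usize, kv_n_heads: usize) -> usize {
    let head_dim = d_model / n_heads;
    let q = n_heads * head_dim; // query slice, == d_model
    let kv = kv_n_heads * head_dim; // one slice each for keys and values
    q + 2 * kv
}

fn main() {
    // Old formula: 3072 + 2 * 8 = 3088, which cannot be split into heads.
    // Fixed formula: 3072 + 2 * 8 * 128 = 5120.
    assert_eq!(wqkv_out_features(3072, 24, 8), 5120);
}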
@@ -155,8 +155,8 @@ struct Ffn {
 impl Ffn {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let hidden = cfg.d_model * cfg.expansion_ratio;
-        let down_proj = linear(cfg.d_model, hidden, vb.pp("down_proj"))?;
-        let up_proj = linear(hidden, cfg.d_model, vb.pp("up_proj"))?;
+        let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
+        let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
         Ok(Self { up_proj, down_proj })
     }
 }
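Besides dropping the biases, this swaps the shapes the two projections had ended up with: up_proj expands d_model to hidden = expansion_ratio * d_model and down_proj contracts back. A shape sketch of the forward pass (the GELU here is an assumption; the exact activation variant is not visible in this diff):

use candle::{Result, Tensor};
use candle_nn::{Linear, Module};

// Illustrative forward pass for the corrected projections:
// (b, seq, d_model) -> up_proj -> (b, seq, hidden) -> down_proj -> (b, seq, d_model)
fn ffn_forward(xs: &Tensor, up_proj: &Linear, down_proj: &Linear) -> Result<Tensor> {
    down_proj.forward(&up_proj.forward(xs)?.gelu()?)
}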
@@ -177,8 +177,12 @@ struct MPTBlock {
 impl MPTBlock {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let norm1 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
-        let norm2 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm1 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_1"))?;
+        let norm2 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_2"))?;
         let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
         let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
         Ok(Self {
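The checkpoint's layer norms ship a scale weight but no bias tensor (MPT's no_bias setting), and passing a bare 1e-5 eps converts into a config with affine: true, which makes candle_nn::layer_norm try to load the missing "bias". With affine: false only "weight" is looked up, and ..Default::default() keeps eps at 1e-5. A minimal sketch of the loader pattern:

use candle::Result;
use candle_nn::{layer_norm, LayerNorm, LayerNormConfig, VarBuilder};

// Illustrative loader: affine = false makes candle_nn load only the scale
// ("weight") tensor and skip "bias", matching a checkpoint whose layer
// norms were trained without biases; eps stays at its 1e-5 default.
fn mpt_layer_norm(d_model: usize, vb: VarBuilder) -> Result<LayerNorm> {
    let ln_cfg = LayerNormConfig {
        affine: false,
        ..Default::default()
    };
    layer_norm(d_model, ln_cfg, vb)
}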
@@ -212,7 +216,7 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
         alibi_bias.reshape((1, 1, 1, seq_len))?
     };
     let mut n_heads2 = 1;
-    while 2 * n_heads2 <= cfg.n_heads {
+    while n_heads2 < cfg.n_heads {
         n_heads2 *= 2
     }
     let slopes = (1..=n_heads2)
@@ -230,8 +234,8 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
             .cloned()
             .collect::<Vec<f32>>()
     };
-    let slopes = Tensor::new(slopes, &Device::Cpu)?;
-    alibi_bias.broadcast_mul(&slopes)
+    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
+    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
 }
 
 #[derive(Debug)]
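Three ALiBi fixes. The loop now computes the smallest power of two n_heads2 >= n_heads (the old condition stopped one doubling short: for n_heads = 24 it gave 16 instead of 32). The slopes tensor is reshaped to (1, n_heads, 1, 1) (candle infers the () dimension) so it can broadcast against the (1, 1, 1, seq_len) bias, which a bare 1-D tensor of length n_heads cannot; and the bias is cast to F32 so the dtypes match in broadcast_mul. A plain-Rust sketch of the slope schedule, assuming the reference MPT subsampling concat([slopes[1::2], slopes[::2]])[:n_heads]:

// Illustrative ALiBi slope schedule: round n_heads up to a power of two,
// compute 1 / 2^(8k / n2) for k in 1..=n2, then subsample if needed.
fn alibi_slopes(n_heads: usize) -> Vec<f32> {
    let mut n2 = 1usize;
    while n2 < n_heads {
        n2 *= 2
    }
    let all: Vec<f32> = (1..=n2)
        .map(|k| 1f32 / 2f32.powf((k * 8) as f32 / n2 as f32))
        .collect();
    if n2 == n_heads {
        all
    } else {
        // odd positions first, then even ones, truncated to n_heads entries
        all.iter()
            .skip(1)
            .step_by(2)
            .chain(all.iter().step_by(2))
            .take(n_heads)
            .cloned()
            .collect()
    }
}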
@@ -250,7 +254,11 @@ impl Model {
             let block = MPTBlock::new(cfg, vb_b.pp(i))?;
             blocks.push(block)
         }
-        let norm_f = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm_f = candle_nn::layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_f"))?;
         Ok(Self {
             wte,
             blocks,
@@ -270,6 +278,7 @@ impl Model {
             xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?
+            .squeeze(1)?
             .matmul(&self.wte.embeddings().t()?)?
             .squeeze(1)
     }
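This is the "another shape fix" from the commit message. narrow(1, seq_len - 1, 1) keeps the last position as a (batch, 1, d_model) tensor, and candle's matmul wants both operands at the same rank, so the 2-D transposed embedding matrix needs a 2-D left-hand side; squeezing before the matmul provides that, and the trailing squeeze(1) context line becomes a no-op (candle follows the PyTorch convention of returning the tensor unchanged when the squeezed dimension is not of size 1). A self-contained shape check with made-up sizes:

use candle::{DType, Device, Result, Tensor};

// Illustrative shape check for the fixed logits computation.
fn logits_shape_check() -> Result<()> {
    let (b, seq, d_model, vocab) = (2usize, 5, 8, 16);
    let xs = Tensor::zeros((b, seq, d_model), DType::F32, &Device::Cpu)?;
    let wte = Tensor::zeros((vocab, d_model), DType::F32, &Device::Cpu)?;
    let logits = xs
        .narrow(1, seq - 1, 1)? // (b, 1, d_model): keep the last position
        .squeeze(1)?            // (b, d_model): ranks now match for matmul
        .matmul(&wte.t()?)?;    // (b, d_model) x (d_model, vocab) -> (b, vocab)
    assert_eq!(logits.dims(), &[b, vocab]);
    Ok(())
}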