MPT fixes. (#1117)

* MPT fixes.

* Another couple fixes.

* Another shape fix.
Author:    Laurent Mazare
Date:      2023-10-17 21:53:31 +01:00
Committed: GitHub
Parent:    a72b50e2c0
Commit:    2cd745a97c

2 changed files with 22 additions and 13 deletions

File: candle-examples/examples/replit-code/main.rs

@@ -215,7 +215,7 @@ fn main() -> Result<()> {
     let config = Config::replit_code_v1_5_3b();
     let device = candle_examples::device(args.cpu)?;
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(&config, vb.pp("transformer"))?;
     println!("loaded the model in {:?}", start.elapsed());
     let mut pipeline = TextGeneration::new(
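The fix here points the VarBuilder at the "transformer" prefix before building the model, since the checkpoint stores its tensors under names such as "transformer.wte.weight" (an assumption based on the usual MPT weight layout; without the prefix every lookup fails). A minimal sketch of how VarBuilder::pp scopes names in candle:

use candle::{DType, Device};
use candle_nn::VarBuilder;

fn main() -> candle::Result<()> {
    // Zero-initialized builder, enough to demonstrate name resolution.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    // After pp("transformer"), every lookup gets a "transformer." prefix.
    let vb_t = vb.pp("transformer");
    // Hypothetical shape; this resolves the name "transformer.wte.weight".
    let _wte = vb_t.get((32768, 3072), "wte.weight")?;
    Ok(())
}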

File: candle-transformers/src/models/mpt.rs

@@ -1,5 +1,5 @@
 #![allow(unused)]
-use crate::models::with_tracing::{linear, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
 /// MPT model used by replit-code-v1_5-3b
 /// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
@@ -57,11 +57,11 @@ struct GroupedQueryAttention {
 impl GroupedQueryAttention {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads;
-        let wqkv = linear(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let head_dim = cfg.d_model / cfg.n_heads;
+        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
+        let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
-        let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+        let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
         let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
         Ok(Self {
             wqkv,
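The old wqkv_size dropped a factor of head_dim: the fused Wqkv projection emits d_model query features plus kv_n_heads * head_dim features for each of K and V, and the switch to linear_no_bias matches a checkpoint trained without biases. A quick arithmetic check of the corrected width, assuming the replit-code-v1_5-3b sizes (d_model = 3072, n_heads = 24, kv_n_heads = 8):

fn main() {
    let (d_model, n_heads, kv_n_heads) = (3072usize, 24, 8);
    let head_dim = d_model / n_heads; // 128
    let old = d_model + 2 * kv_n_heads; // 3088, short by a factor of head_dim
    let new = d_model + 2 * kv_n_heads * head_dim; // 3072 + 2048 = 5120
    assert_eq!((old, new), (3088, 5120));
}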
@@ -155,8 +155,8 @@ struct Ffn {
 impl Ffn {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let hidden = cfg.d_model * cfg.expansion_ratio;
-        let down_proj = linear(cfg.d_model, hidden, vb.pp("down_proj"))?;
-        let up_proj = linear(hidden, cfg.d_model, vb.pp("up_proj"))?;
+        let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
+        let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
         Ok(Self { up_proj, down_proj })
     }
 }
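Besides dropping the biases, this hunk uncrosses the two projection names: the tensor stored at "up_proj" is the widening map and "down_proj" the narrowing one, so the old assignment tried to load each weight with the transposed shape. A small sketch of the expected weight shapes, assuming expansion_ratio = 8 as in replit-code-v1_5-3b:

fn main() {
    let (d_model, expansion_ratio) = (3072usize, 8);
    let hidden = d_model * expansion_ratio; // 24576
    // A candle Linear(in, out) stores its weight as (out, in).
    let up_proj_weight = (hidden, d_model); // widens: d_model -> hidden
    let down_proj_weight = (d_model, hidden); // narrows: hidden -> d_model
    assert_eq!(up_proj_weight, (24576, 3072));
    assert_eq!(down_proj_weight, (3072, 24576));
}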
@@ -177,8 +177,12 @@ struct MPTBlock {
 impl MPTBlock {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let norm1 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
-        let norm2 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm1 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_1"))?;
+        let norm2 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_2"))?;
         let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
         let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
         Ok(Self {
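MPT-style checkpoints such as replit-code-v1_5-3b are trained without biases, so their layer norms ship a scale but no bias tensor, and candle's default LayerNormConfig would request a "bias" entry that does not exist. Setting affine: false skips that lookup, while ..Default::default() keeps eps at 1e-5, the value previously hard-coded. A minimal sketch, assuming candle_nn's layer_norm semantics for affine:

use candle::{DType, Device};
use candle_nn::{layer_norm, LayerNormConfig, VarBuilder};

fn main() -> candle::Result<()> {
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let ln_cfg = LayerNormConfig {
        affine: false,       // fetch "weight" only, never "bias"
        ..Default::default() // keeps eps = 1e-5, matching the old code
    };
    let _norm = layer_norm(3072, ln_cfg, vb.pp("norm_1"))?;
    Ok(())
}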
@@ -212,7 +216,7 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
         alibi_bias.reshape((1, 1, 1, seq_len))?
     };
     let mut n_heads2 = 1;
-    while 2 * n_heads2 <= cfg.n_heads {
+    while n_heads2 < cfg.n_heads {
         n_heads2 *= 2
     }
     let slopes = (1..=n_heads2)
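The loop condition flips which power of two gets computed: the old form found the largest power of two at or below n_heads, while ALiBi defines its slope schedule over the smallest power of two at or above n_heads and then subsamples the extras. For the 24 heads assumed for replit-code-v1_5-3b, that is 16 versus 32 slope candidates. A side-by-side check in plain Rust:

fn main() {
    let n_heads = 24usize;
    // Old condition: largest power of two <= n_heads.
    let mut below = 1usize;
    while 2 * below <= n_heads {
        below *= 2;
    }
    // New condition: smallest power of two >= n_heads.
    let mut above = 1usize;
    while above < n_heads {
        above *= 2;
    }
    assert_eq!((below, above), (16, 32));
}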
@@ -230,8 +234,8 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
             .cloned()
             .collect::<Vec<f32>>()
     };
-    let slopes = Tensor::new(slopes, &Device::Cpu)?;
-    alibi_bias.broadcast_mul(&slopes)
+    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
+    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
 }

 #[derive(Debug)]
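The reshape turns the flat slope vector into a (1, n_heads, 1, 1) tensor, with () asking candle to infer the one unknown dimension, so broadcasting against the (1, 1, 1, seq_len) bias yields a distinct bias row per head; the to_dtype(DType::F32) guards the multiply against a dtype mismatch with the f32 slopes. A minimal shape check, assuming candle's broadcasting rules:

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let (n_heads, seq_len) = (24usize, 16usize);
    let bias = Tensor::zeros((1, 1, 1, seq_len), DType::F32, &Device::Cpu)?;
    let slopes = Tensor::ones(n_heads, DType::F32, &Device::Cpu)?.reshape((1, (), 1, 1))?;
    // (1, 1, 1, seq_len) * (1, n_heads, 1, 1) broadcasts to (1, n_heads, 1, seq_len).
    let out = bias.broadcast_mul(&slopes)?;
    assert_eq!(out.dims(), &[1, n_heads, 1, seq_len]);
    Ok(())
}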
@@ -250,7 +254,11 @@ impl Model {
             let block = MPTBlock::new(cfg, vb_b.pp(i))?;
             blocks.push(block)
         }
-        let norm_f = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm_f = candle_nn::layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_f"))?;
         Ok(Self {
             wte,
             blocks,
@@ -270,6 +278,7 @@ impl Model {
             xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?
+            .squeeze(1)?
             .matmul(&self.wte.embeddings().t()?)?
             .squeeze(1)
     }
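This is the "shape fix" from the commit message: narrow keeps a singleton time dimension, so the hidden state used to reach the tied-embedding matmul as (batch, 1, d_model) rather than the rank-2 (batch, d_model) the (d_model, vocab) weight expects. Squeezing first makes the matmul two-dimensional; the trailing squeeze(1) is then a no-op on the (batch, vocab) logits, since candle follows the PyTorch convention of leaving non-unit dimensions alone. A small sketch with made-up sizes:

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let (b, t, d, v) = (2usize, 5, 8, 11);
    let xs = Tensor::zeros((b, t, d), DType::F32, &Device::Cpu)?;
    let wte = Tensor::zeros((v, d), DType::F32, &Device::Cpu)?;
    let logits = xs
        .narrow(1, t - 1, 1)? // (b, 1, d): keep only the last position
        .squeeze(1)? // (b, d): drop the singleton time dim
        .matmul(&wte.t()?)?; // (b, d) x (d, v) = (b, v)
    assert_eq!(logits.dims(), &[b, v]);
    Ok(())
}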