mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
MPT fixes. (#1117)
* MPT fixes. * Another couple fixes. * Another shape fix.
This commit is contained in:
@ -215,7 +215,7 @@ fn main() -> Result<()> {
|
||||
let config = Config::replit_code_v1_5_3b();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
let model = Model::new(&config, vb.pp("transformer"))?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
|
@ -1,5 +1,5 @@
|
||||
#![allow(unused)]
|
||||
use crate::models::with_tracing::{linear, Embedding as E, Linear};
|
||||
use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
|
||||
/// MPT model used by replit-code-v1_5-3b
|
||||
/// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
@ -57,11 +57,11 @@ struct GroupedQueryAttention {
|
||||
|
||||
impl GroupedQueryAttention {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads;
|
||||
let wqkv = linear(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
|
||||
let head_dim = cfg.d_model / cfg.n_heads;
|
||||
let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
|
||||
let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
|
||||
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||
let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
|
||||
let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
|
||||
let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
|
||||
Ok(Self {
|
||||
wqkv,
|
||||
@ -155,8 +155,8 @@ struct Ffn {
|
||||
impl Ffn {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden = cfg.d_model * cfg.expansion_ratio;
|
||||
let down_proj = linear(cfg.d_model, hidden, vb.pp("down_proj"))?;
|
||||
let up_proj = linear(hidden, cfg.d_model, vb.pp("up_proj"))?;
|
||||
let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
|
||||
let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
|
||||
Ok(Self { up_proj, down_proj })
|
||||
}
|
||||
}
|
||||
@ -177,8 +177,12 @@ struct MPTBlock {
|
||||
|
||||
impl MPTBlock {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let norm1 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
|
||||
let norm2 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
|
||||
let ln_cfg = candle_nn::LayerNormConfig {
|
||||
affine: false,
|
||||
..Default::default()
|
||||
};
|
||||
let norm1 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_1"))?;
|
||||
let norm2 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_2"))?;
|
||||
let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
|
||||
let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
|
||||
Ok(Self {
|
||||
@ -212,7 +216,7 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
|
||||
alibi_bias.reshape((1, 1, 1, seq_len))?
|
||||
};
|
||||
let mut n_heads2 = 1;
|
||||
while 2 * n_heads2 <= cfg.n_heads {
|
||||
while n_heads2 < cfg.n_heads {
|
||||
n_heads2 *= 2
|
||||
}
|
||||
let slopes = (1..=n_heads2)
|
||||
@ -230,8 +234,8 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
|
||||
.cloned()
|
||||
.collect::<Vec<f32>>()
|
||||
};
|
||||
let slopes = Tensor::new(slopes, &Device::Cpu)?;
|
||||
alibi_bias.broadcast_mul(&slopes)
|
||||
let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
|
||||
alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -250,7 +254,11 @@ impl Model {
|
||||
let block = MPTBlock::new(cfg, vb_b.pp(i))?;
|
||||
blocks.push(block)
|
||||
}
|
||||
let norm_f = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
|
||||
let ln_cfg = candle_nn::LayerNormConfig {
|
||||
affine: false,
|
||||
..Default::default()
|
||||
};
|
||||
let norm_f = candle_nn::layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_f"))?;
|
||||
Ok(Self {
|
||||
wte,
|
||||
blocks,
|
||||
@ -270,6 +278,7 @@ impl Model {
|
||||
xs = block.forward(&xs, mask.as_ref())?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.squeeze(1)?
|
||||
.matmul(&self.wte.embeddings().t()?)?
|
||||
.squeeze(1)
|
||||
}
|
||||
|
Reference in New Issue
Block a user