MPT fixes. (#1117)
* MPT fixes.
* Another couple fixes.
* Another shape fix.
@@ -215,7 +215,7 @@ fn main() -> Result<()> {
     let config = Config::replit_code_v1_5_3b();
     let device = candle_examples::device(args.cpu)?;
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(&config, vb.pp("transformer"))?;
     println!("loaded the model in {:?}", start.elapsed());

     let mut pipeline = TextGeneration::new(
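Note on the `vb.pp("transformer")` change above: the checkpoint stores its tensors under a `transformer.` prefix, and `VarBuilder::pp` pushes that prefix onto every subsequent lookup, so the model's internal names resolve against the right paths. A minimal sketch of the behavior; the shape and the `wte.weight` name are illustrative, not read from the checkpoint:

```rust
use candle_nn::VarBuilder;

fn scoped_lookup(vb: VarBuilder) -> candle::Result<()> {
    // `pp` ("push prefix") returns a VarBuilder scoped to a sub-path.
    let vb_t = vb.pp("transformer");
    // This asks the backing store for "transformer.wte.weight".
    let _wte = vb_t.pp("wte").get((32, 8), "weight")?;
    Ok(())
}
```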
@@ -1,5 +1,5 @@
 #![allow(unused)]
-use crate::models::with_tracing::{linear, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
 /// MPT model used by replit-code-v1_5-3b
 /// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
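Background for the `linear` to `linear_no_bias` swap, here and in the hunks below: this checkpoint is exported without bias tensors, so `linear`, which fetches both `weight` and `bias` through the `VarBuilder`, fails on the missing `bias`, while `linear_no_bias` fetches `weight` only. A hedged sketch using `candle_nn` directly (the diff goes through candle-transformers' `with_tracing` wrappers, which add tracing spans on top of the same constructors):

```rust
use candle_nn::{Linear, VarBuilder};

// Loads "<prefix>.weight" only; `candle_nn::linear` would additionally
// request "<prefix>.bias" and error out on a bias-free checkpoint.
fn load_proj(d_in: usize, d_out: usize, vb: VarBuilder) -> candle::Result<Linear> {
    candle_nn::linear_no_bias(d_in, d_out, vb)
}
```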
@@ -57,11 +57,11 @@ struct GroupedQueryAttention {

 impl GroupedQueryAttention {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads;
-        let wqkv = linear(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let head_dim = cfg.d_model / cfg.n_heads;
+        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
+        let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
-        let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+        let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
         let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
         Ok(Self {
             wqkv,
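The `wqkv_size` fix is the shape bug in the fused QKV projection: with grouped-query attention, the fused `Wqkv` output holds the full query (`n_heads * head_dim = d_model`) plus one key and one value per KV head, so its width is `d_model + 2 * kv_n_heads * head_dim`; the old expression dropped the `head_dim` factor. A quick check with made-up config numbers (not the actual replit-code-v1_5-3b values):

```rust
fn main() {
    // Hypothetical config values, chosen only to make the arithmetic visible.
    let d_model = 2048;
    let n_heads = 16;
    let kv_n_heads = 8; // grouped-query attention: fewer KV heads than Q heads
    let head_dim = d_model / n_heads; // 128

    // Q spans n_heads * head_dim = d_model columns; K and V each span
    // kv_n_heads * head_dim columns.
    let fixed = d_model + 2 * kv_n_heads * head_dim;
    assert_eq!(fixed, 4096);

    // The old expression under-sizes the projection and cannot match the
    // checkpoint's Wqkv weight shape.
    let broken = d_model + 2 * kv_n_heads;
    assert_eq!(broken, 2064);
}
```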
@@ -155,8 +155,8 @@ struct Ffn {
 impl Ffn {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let hidden = cfg.d_model * cfg.expansion_ratio;
-        let down_proj = linear(cfg.d_model, hidden, vb.pp("down_proj"))?;
-        let up_proj = linear(hidden, cfg.d_model, vb.pp("up_proj"))?;
+        let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
+        let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
         Ok(Self { up_proj, down_proj })
     }
 }
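Besides dropping the biases, this hunk fixes crossed wiring in the feed-forward block: `up_proj` should expand `d_model` to `d_model * expansion_ratio` and `down_proj` should project back down, but the old code loaded each from the other's checkpoint tensor. A sketch of the intended shape flow; the GELU here is an assumption about the activation, not something this diff shows:

```rust
use candle::{Module, Result, Tensor};
use candle_nn::Linear;

// xs: (batch, seq_len, d_model)
// up_proj:   d_model -> d_model * expansion_ratio
// down_proj: d_model * expansion_ratio -> d_model
fn ffn_forward(xs: &Tensor, up_proj: &Linear, down_proj: &Linear) -> Result<Tensor> {
    down_proj.forward(&up_proj.forward(xs)?.gelu_erf()?)
}
```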
@@ -177,8 +177,12 @@ struct MPTBlock {

 impl MPTBlock {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let norm1 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
-        let norm2 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm1 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_1"))?;
+        let norm2 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_2"))?;
         let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
         let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
         Ok(Self {
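The layer-norm change follows the same no-bias theme: passing a bare `1e-5` converts into a config with `affine: true`, which makes candle load both `weight` and `bias`; with `affine: false` only `weight` is loaded, matching norms that carry no bias tensor. `..Default::default()` keeps the remaining fields, including the epsilon, at their `candle_nn` defaults. A minimal sketch, assuming the norm lives under a `norm_1` prefix as in the diff:

```rust
use candle_nn::{layer_norm, LayerNorm, LayerNormConfig, VarBuilder};

fn make_norm(d_model: usize, vb: VarBuilder) -> candle::Result<LayerNorm> {
    let ln_cfg = LayerNormConfig {
        affine: false, // weight-only: do not request a missing `bias` tensor
        ..Default::default()
    };
    layer_norm(d_model, ln_cfg, vb.pp("norm_1"))
}
```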
@@ -212,7 +216,7 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
         alibi_bias.reshape((1, 1, 1, seq_len))?
     };
     let mut n_heads2 = 1;
-    while 2 * n_heads2 <= cfg.n_heads {
+    while n_heads2 < cfg.n_heads {
         n_heads2 *= 2
     }
     let slopes = (1..=n_heads2)
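The loop condition changes which power of two anchors the ALiBi slopes: the old form stops at the largest power of two at or below `cfg.n_heads`, the new form continues to the smallest power of two at or above it, which is the rounding the reference ALiBi recipe uses when the head count is not itself a power of two. A side-by-side comparison; the head count 24 is illustrative:

```rust
fn old_round(n_heads: usize) -> usize {
    let mut n = 1;
    while 2 * n <= n_heads {
        n *= 2;
    }
    n // largest power of two <= n_heads
}

fn new_round(n_heads: usize) -> usize {
    let mut n = 1;
    while n < n_heads {
        n *= 2;
    }
    n // smallest power of two >= n_heads
}

fn main() {
    assert_eq!(old_round(24), 16);
    assert_eq!(new_round(24), 32);
    // The two agree whenever n_heads is already a power of two.
    assert_eq!(old_round(32), 32);
    assert_eq!(new_round(32), 32);
}
```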
@@ -230,8 +234,8 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
             .cloned()
             .collect::<Vec<f32>>()
     };
-    let slopes = Tensor::new(slopes, &Device::Cpu)?;
-    alibi_bias.broadcast_mul(&slopes)
+    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
+    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
 }

 #[derive(Debug)]
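Two fixes here. First, `()` inside a candle `reshape` means "infer this dimension", so the 1-D slopes vector becomes `(1, n_heads, 1, 1)` and can broadcast against the `(1, 1, 1, seq_len)` bias; as a flat vector it only lined up when `n_heads` happened to equal `seq_len`. Second, `broadcast_mul` requires matching dtypes, hence the explicit cast to `F32`. A small runnable demo with made-up sizes (4 heads, seq_len 8):

```rust
use candle::{DType, Device, Result, Tensor};

fn demo() -> Result<Tensor> {
    let seq_len = 8usize;
    // Four illustrative slopes; `()` infers the head dimension.
    let slopes = Tensor::new(&[0.5f32, 0.25, 0.125, 0.0625], &Device::Cpu)?
        .reshape((1, (), 1, 1))?; // (1, 4, 1, 1)
    // An integer position ramp standing in for the alibi bias.
    let bias = Tensor::arange(0u32, seq_len as u32, &Device::Cpu)?
        .reshape((1, 1, 1, seq_len))?; // (1, 1, 1, 8)
    // Cast so both operands are F32, then broadcast to (1, 4, 1, 8).
    bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
}
```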
@@ -250,7 +254,11 @@ impl Model {
             let block = MPTBlock::new(cfg, vb_b.pp(i))?;
             blocks.push(block)
         }
-        let norm_f = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm_f = candle_nn::layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_f"))?;
         Ok(Self {
             wte,
             blocks,
@@ -270,6 +278,7 @@ impl Model {
             xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?
+            .squeeze(1)?
             .matmul(&self.wte.embeddings().t()?)?
             .squeeze(1)
     }
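The added `.squeeze(1)?` is the "another shape fix" from the commit message: `narrow` keeps the last position but leaves a singleton time axis, giving `(batch, 1, d_model)`, so squeezing it first makes the product with the transposed embedding matrix a plain 2-D matmul. A toy-sized walk-through; batch 2, seq_len 5, d_model 3, and vocab 7 are illustrative:

```rust
use candle::{DType, Device, Result, Tensor};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::zeros((2, 5, 3), DType::F32, &dev)?; // (batch, seq_len, d_model)
    let wte = Tensor::zeros((7, 3), DType::F32, &dev)?; // (vocab, d_model), tied embeddings
    let logits = xs
        .narrow(1, 4, 1)? // last position only: (2, 1, 3)
        .squeeze(1)? // (2, 3); without this the matmul operands differ in rank
        .matmul(&wte.t()?)?; // (2, 3) x (3, 7) -> (2, 7)
    assert_eq!(logits.dims(), &[2, 7]);
    Ok(())
}
```

The trailing `.squeeze(1)` left in the diff is harmless after the fix: candle's `squeeze` returns the tensor unchanged when the requested dimension does not have size 1.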
|