mirror of https://github.com/huggingface/candle.git
Line up the llama implementation with the python-transformers one. (#271)

* Line up the llama implementation with the python-transformers one.
* Also line up the multiprocess version.
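In short: the diff adds a configurable rms_norm_eps to Config and threads it through RmsNorm, replacing the epsilon that was hard-coded to 1e-6 in forward(); it also drops the trivial new constructors in favor of struct literals. The field name matches the rms_norm_eps option of the python-transformers LlamaConfig. As a minimal sketch of the computation being parameterized (plain Rust over a single hidden vector, not the candle tensor API):

    // RMSNorm over one hidden vector; `eps` is the newly configurable
    // rms_norm_eps (set to 1e-5 in this diff) rather than a hard-coded 1e-6.
    fn rms_norm(x: &[f32], scale: &[f32], eps: f32) -> Vec<f32> {
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
        let inv_rms = (mean_sq + eps).sqrt().recip();
        x.iter().zip(scale).map(|(v, s)| v * inv_rms * s).collect()
    }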
@@ -14,6 +14,7 @@ pub struct Config {
     pub n_embd: usize,
     pub n_key_value_head: usize,
     pub use_flash_attn: bool,
+    pub rms_norm_eps: f64,
 }
 
 impl Config {
@@ -27,6 +28,7 @@ impl Config {
             n_embd: 4096,
             n_key_value_head: 32,
             use_flash_attn,
+            rms_norm_eps: 1e-5,
         }
     }
 }
@@ -102,16 +104,13 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 
 struct RmsNorm {
     scale: Tensor,
+    eps: f64,
 }
 
 impl RmsNorm {
-    fn load(size: usize, vb: VarBuilder) -> Result<Self> {
+    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
         let scale = vb.get(size, "weight")?;
-        Ok(Self::new(scale))
-    }
-
-    fn new(scale: Tensor) -> Self {
-        Self { scale }
+        Ok(Self { scale, eps })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -121,7 +120,7 @@ impl RmsNorm {
         let (b_sz, seq_len, hidden_size) = x.dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
+        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
         let size = self.scale.dims1()?;
         let scale = self
             .scale
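With eps stored on the struct, load and forward now agree on the value coming from the config. A hypothetical call site under the new signature, using the names from this diff:

    // eps flows from Config::rms_norm_eps instead of a 1e-6 literal in forward().
    let norm = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
    let y = norm.forward(&x)?;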
@@ -292,14 +291,6 @@ struct Mlp {
 }
 
 impl Mlp {
-    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
@@ -311,7 +302,11 @@ impl Mlp {
         let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
         let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
         let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
+        Ok(Self {
+            c_fc1,
+            c_fc2,
+            c_proj,
+        })
     }
 }
 
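The constructor removals here and in the Block and Llama hunks below all rely on Rust's field init shorthand: when the local bindings already carry the field names, the struct literal says everything the removed new did, e.g.:

    // Locals named c_fc1/c_fc2/c_proj fill the matching fields directly,
    // so Mlp::new adds nothing.
    Ok(Self { c_fc1, c_fc2, c_proj })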
@@ -323,15 +318,6 @@ struct Block {
 }
 
 impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
-        Self {
-            rms_1,
-            attn,
-            rms_2,
-            mlp,
-        }
-    }
-
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let residual = x;
         let x = self.rms_1.forward(x)?;
@@ -344,15 +330,18 @@ impl Block {
     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
-        let post_attention_layernorm =
-            RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
-        Ok(Self::new(
-            input_layernorm,
+        let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let rms_2 = RmsNorm::load(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
+        Ok(Self {
+            rms_1,
             attn,
-            post_attention_layernorm,
+            rms_2,
             mlp,
-        ))
+        })
     }
 }
 
@@ -364,15 +353,6 @@ pub struct Llama {
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
-        Self {
-            wte,
-            blocks,
-            ln_f,
-            lm_head,
-        }
-    }
-
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
@@ -388,11 +368,16 @@ impl Llama {
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
+        let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
             .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
             .collect();
 
-        Ok(Self::new(wte, blocks, norm, lm_head))
+        Ok(Self {
+            wte,
+            blocks,
+            ln_f,
+            lm_head,
+        })
     }
 }
@@ -225,7 +225,7 @@ impl RmsNorm {
         let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
+        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().dims1()?;
         let scale = self
             .scale
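This last hunk appears to be the multiprocess version mentioned in the commit message: its RmsNorm keeps a hard-coded epsilon, so only the literal changes, from 1e-6 to 1e-5, to stay consistent with the rms_norm_eps default introduced above.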