From 3e3def41346daae52a0e438513b282e8bba14e73 Mon Sep 17 00:00:00 2001
From: Frkri <57220519+frkri@users.noreply.github.com>
Date: Sat, 2 Mar 2024 09:56:57 +0100
Subject: [PATCH] Update StableLM config (#1787)

---
 .../src/models/quantized_stable_lm.rs       |  6 +++---
 candle-transformers/src/models/stable_lm.rs | 18 +++++++++---------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index fa301f6c..c79877b6 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -186,10 +186,10 @@ impl DecoderLayer {
     fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
-        let input_layernorm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm = layer_norm(
             cfg.hidden_size,
-            cfg.norm_eps,
+            cfg.layer_norm_eps,
             vb.pp("post_attention_layernorm"),
         )?;
         Ok(Self {
@@ -240,7 +240,7 @@ impl Model {
             let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
-        let norm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
+        let norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
         let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index a49b8282..f46d3a2c 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -4,7 +4,7 @@ use candle_nn::{Activation, LayerNorm, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
 
-// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
+// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm.py
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     pub(crate) vocab_size: usize,
@@ -14,10 +14,10 @@ pub struct Config {
     pub(crate) num_attention_heads: usize,
     pub(crate) num_key_value_heads: usize,
     pub(crate) hidden_act: Activation,
-    pub(crate) rope_pct: f64,
+    pub(crate) partial_rotary_factor: f64,
     pub(crate) rope_theta: f64,
     pub(crate) max_position_embeddings: usize,
-    pub(crate) norm_eps: f64,
+    pub(crate) layer_norm_eps: f64,
     pub(crate) use_cache: bool,
     #[serde(default)]
     pub(crate) use_qkv_bias: bool, // Used in StableLM-2
@@ -35,10 +35,10 @@ impl Config {
             num_attention_heads: 32,
             num_key_value_heads: 32,
             hidden_act: Activation::Silu,
-            rope_pct: 0.25,
+            partial_rotary_factor: 0.25,
             rope_theta: 10_000.,
             max_position_embeddings: 4096,
-            norm_eps: 1e-5,
+            layer_norm_eps: 1e-5,
             use_qkv_bias: false,
             use_cache: true,
             use_flash_attn,
@@ -50,7 +50,7 @@ impl Config {
     }
 
     pub fn rotary_ndims(&self) -> usize {
-        (self.head_dim() as f64 * self.rope_pct) as usize
+        (self.head_dim() as f64 * self.partial_rotary_factor) as usize
     }
 
     pub fn num_kv_groups(&self) -> usize {
@@ -317,10 +317,10 @@ impl DecoderLayer {
         let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
         let input_layernorm =
-            candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
+            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm = candle_nn::layer_norm(
             cfg.hidden_size,
-            cfg.norm_eps,
+            cfg.layer_norm_eps,
             vb.pp("post_attention_layernorm"),
         )?;
         Ok(Self {
@@ -372,7 +372,7 @@ impl Model {
             let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
-        let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
+        let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
         let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
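Note (reviewer sketch, not part of the patch): the renamed fields follow the key names used by the upstream Hugging Face StableLM configuration (partial_rotary_factor, layer_norm_eps), which is what lets a serialized config.json deserialize into Config without serde aliases. The standalone Rust sketch below illustrates that mapping and the rotary_ndims computation; the MiniStableLmConfig struct, the JSON values, and the assumed 32 attention heads are illustrative placeholders rather than anything taken from the patch, and it assumes the serde (with the derive feature) and serde_json crates.

// Standalone sketch: deserialize the renamed StableLM config keys and derive
// the partial-rotary dimension the same way Config::rotary_ndims does.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct MiniStableLmConfig {
    hidden_size: usize,
    // Previously `rope_pct` in candle; the upstream key is `partial_rotary_factor`.
    partial_rotary_factor: f64,
    // Previously `norm_eps` in candle; the upstream key is `layer_norm_eps`.
    layer_norm_eps: f64,
}

fn main() -> Result<(), serde_json::Error> {
    // Illustrative values only; a real config.json carries many more keys,
    // which serde ignores by default.
    let raw = r#"{
        "hidden_size": 2560,
        "partial_rotary_factor": 0.25,
        "layer_norm_eps": 1e-5
    }"#;
    let cfg: MiniStableLmConfig = serde_json::from_str(raw)?;

    // Same arithmetic as Config::rotary_ndims in the patched file, assuming
    // 32 attention heads (head_dim = hidden_size / num_attention_heads).
    let head_dim = cfg.hidden_size / 32;
    let rotary_ndims = (head_dim as f64 * cfg.partial_rotary_factor) as usize;
    println!(
        "rotary_ndims = {rotary_ndims}, layer_norm_eps = {}",
        cfg.layer_norm_eps
    );
    Ok(())
}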