Mirror of https://github.com/huggingface/candle.git
Update StableLM config (#1787)
@@ -186,10 +186,10 @@ impl DecoderLayer {
     fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
-        let input_layernorm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
+        let input_layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm = layer_norm(
             cfg.hidden_size,
-            cfg.norm_eps,
+            cfg.layer_norm_eps,
             vb.pp("post_attention_layernorm"),
         )?;
         Ok(Self {
@@ -240,7 +240,7 @@ impl Model {
             let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
-        let norm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
+        let norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
         let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
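Context for the layer_norm calls above (not part of the diff): in candle_nn, layer_norm is generic over anything convertible into a LayerNormConfig, which is why the bare f64 eps field can be passed directly. A minimal sketch of the call pattern, assuming candle-core and candle-nn as dependencies and using a zero-initialized VarBuilder in place of real checkpoint weights:

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{layer_norm, Module, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Zero-initialized weights stand in for a real checkpoint here.
    let vb = VarBuilder::zeros(DType::F32, &dev);
    let (hidden_size, layer_norm_eps) = (8, 1e-5);
    // Same shape of call as in the diff: size, eps, name-scoped VarBuilder.
    let ln = layer_norm(hidden_size, layer_norm_eps, vb.pp("input_layernorm"))?;
    let xs = Tensor::randn(0f32, 1f32, (2, hidden_size), &dev)?;
    let ys = ln.forward(&xs)?;
    println!("{:?}", ys.shape());
    Ok(())
}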
@@ -4,7 +4,7 @@ use candle_nn::{Activation, LayerNorm, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;

-// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
+// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm.py
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     pub(crate) vocab_size: usize,
@@ -14,10 +14,10 @@ pub struct Config {
     pub(crate) num_attention_heads: usize,
     pub(crate) num_key_value_heads: usize,
     pub(crate) hidden_act: Activation,
-    pub(crate) rope_pct: f64,
+    pub(crate) partial_rotary_factor: f64,
     pub(crate) rope_theta: f64,
     pub(crate) max_position_embeddings: usize,
-    pub(crate) norm_eps: f64,
+    pub(crate) layer_norm_eps: f64,
     pub(crate) use_cache: bool,
     #[serde(default)]
     pub(crate) use_qkv_bias: bool, // Used in StableLM-2
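The field renames matter because #[derive(Deserialize)] maps JSON keys to struct field names verbatim: with the old names, a config.json using the upstream partial_rotary_factor / layer_norm_eps keys would fail to parse (the old fields would be reported as missing). A minimal sketch with a hypothetical MiniConfig, not part of the commit, assuming serde and serde_json as dependencies:

use serde::Deserialize;

// Hypothetical two-field stand-in for the real Config.
#[derive(Debug, Deserialize)]
struct MiniConfig {
    partial_rotary_factor: f64,
    layer_norm_eps: f64,
}

fn main() -> Result<(), serde_json::Error> {
    // Keys spelled as in the updated configuration_stablelm.py naming.
    let json = r#"{"partial_rotary_factor": 0.25, "layer_norm_eps": 1e-5}"#;
    let cfg: MiniConfig = serde_json::from_str(json)?;
    assert_eq!(cfg.partial_rotary_factor, 0.25);
    assert_eq!(cfg.layer_norm_eps, 1e-5);
    Ok(())
}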
@@ -35,10 +35,10 @@ impl Config {
             num_attention_heads: 32,
             num_key_value_heads: 32,
             hidden_act: Activation::Silu,
-            rope_pct: 0.25,
+            partial_rotary_factor: 0.25,
             rope_theta: 10_000.,
             max_position_embeddings: 4096,
-            norm_eps: 1e-5,
+            layer_norm_eps: 1e-5,
             use_qkv_bias: false,
             use_cache: true,
             use_flash_attn,
@@ -50,7 +50,7 @@ impl Config {
     }

     pub fn rotary_ndims(&self) -> usize {
-        (self.head_dim() as f64 * self.rope_pct) as usize
+        (self.head_dim() as f64 * self.partial_rotary_factor) as usize
     }

     pub fn num_kv_groups(&self) -> usize {
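A quick worked example of rotary_ndims under the stablelm-3b-4e1t defaults. Note that hidden_size = 2560 is an assumption taken from the upstream model config; only num_attention_heads = 32 and partial_rotary_factor = 0.25 appear in this diff:

// Standalone sketch of the partial-rotary arithmetic.
fn main() {
    let hidden_size = 2560usize; // assumed from stablelm-3b-4e1t, not in this diff
    let num_attention_heads = 32usize;
    let partial_rotary_factor = 0.25f64;

    let head_dim = hidden_size / num_attention_heads; // 2560 / 32 = 80
    // Only this leading slice of each head receives rotary embeddings;
    // the remaining head dimensions pass through unrotated.
    let rotary_ndims = (head_dim as f64 * partial_rotary_factor) as usize; // 80 * 0.25 = 20
    assert_eq!(rotary_ndims, 20);
}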
@@ -317,10 +317,10 @@ impl DecoderLayer {
         let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
         let input_layernorm =
-            candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
+            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm = candle_nn::layer_norm(
             cfg.hidden_size,
-            cfg.norm_eps,
+            cfg.layer_norm_eps,
             vb.pp("post_attention_layernorm"),
         )?;
         Ok(Self {
@@ -372,7 +372,7 @@ impl Model {
             let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
-        let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
+        let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
         let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         Ok(Self {
             embed_tokens,
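If reading older configs that still use the configuration_stablelm_epoch.py key spellings were a goal (this commit does not attempt that), serde aliases could accept both key sets. A hedged sketch with a hypothetical CompatConfig:

use serde::Deserialize;

// Hypothetical compatibility shim: each field also accepts its pre-rename key.
#[derive(Debug, Deserialize)]
struct CompatConfig {
    #[serde(alias = "rope_pct")]
    partial_rotary_factor: f64,
    #[serde(alias = "norm_eps")]
    layer_norm_eps: f64,
}

fn main() -> Result<(), serde_json::Error> {
    // Old-style keys still parse thanks to the aliases.
    let old = r#"{"rope_pct": 0.25, "norm_eps": 1e-5}"#;
    let cfg: CompatConfig = serde_json::from_str(old)?;
    assert_eq!(cfg.partial_rotary_factor, 0.25);
    Ok(())
}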