Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Add StableLM-2, StableLM Code and Zephyr variants (#1650)
* Add StableLM Code and Zephyr variants
* Add V2 models
* Update README
@@ -1,10 +1,11 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm, VarBuilder};
+use serde::Deserialize;
 use std::sync::Arc;
 
 // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     pub(crate) vocab_size: usize,
     pub(crate) intermediate_size: usize,
@@ -18,7 +19,10 @@ pub struct Config {
     pub(crate) max_position_embeddings: usize,
     pub(crate) norm_eps: f64,
     pub(crate) use_cache: bool,
-    pub(crate) use_flash_attn: bool,
+    #[serde(default)]
+    pub(crate) use_qkv_bias: bool, // Used in StableLM-2
+    #[serde(default)]
+    pub(crate) use_flash_attn: bool, // Not in config.json
 }
 
 impl Config {
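Note on the `#[serde(default)]` attributes above: they keep the struct backward compatible, since StableLM v1 config.json files carry no `use_qkv_bias` key and deserialization must fall back to `false` rather than fail. A minimal sketch of that behavior, using a trimmed-down stand-in struct and illustrative vocab sizes rather than the real `Config`:

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct Config {
        vocab_size: usize,
        #[serde(default)] // missing key -> bool::default(), i.e. false
        use_qkv_bias: bool,
    }

    fn main() -> serde_json::Result<()> {
        // v1-style config: no `use_qkv_bias` key present.
        let v1: Config = serde_json::from_str(r#"{"vocab_size": 50304}"#)?;
        assert!(!v1.use_qkv_bias);
        // StableLM-2-style config: the key is set explicitly.
        let v2: Config =
            serde_json::from_str(r#"{"vocab_size": 100352, "use_qkv_bias": true}"#)?;
        assert!(v2.use_qkv_bias);
        Ok(())
    }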
@@ -35,6 +39,7 @@ impl Config {
             rope_theta: 10_000.,
             max_position_embeddings: 4096,
             norm_eps: 1e-5,
+            use_qkv_bias: false,
             use_cache: true,
             use_flash_attn,
         }
@@ -51,6 +56,10 @@ impl Config {
     pub fn num_kv_groups(&self) -> usize {
         self.num_attention_heads / self.num_key_value_heads
     }
+
+    pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
+        self.use_flash_attn = use_flash_attn
+    }
 }
 
 #[derive(Debug)]
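Note: with `Deserialize` derived on `Config` and this new setter, a loader can parse config.json directly and still control the one field that is not stored in the file. A hedged usage sketch, assuming the `serde_json` and `anyhow` crates and an illustrative file path:

    use candle_transformers::models::stablelm::Config;

    fn load_config(path: &str, use_flash_attn: bool) -> anyhow::Result<Config> {
        // `use_flash_attn` is not part of config.json, so set it after parsing.
        let mut cfg: Config = serde_json::from_slice(&std::fs::read(path)?)?;
        cfg.set_use_flash_attn(use_flash_attn);
        Ok(cfg)
    }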
@@ -179,9 +188,15 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
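Note: `linear` and `linear_no_bias` take the same arguments (as the calls above show), so the `if` expression coerces both function items to a single `fn` pointer and all three projections share whichever constructor the config selects. A standalone sketch of the pattern, with illustrative functions in place of the candle constructors:

    fn with_bias(inp: usize, out: usize) -> String {
        format!("Linear({inp}x{out} + bias)")
    }

    fn without_bias(inp: usize, out: usize) -> String {
        format!("Linear({inp}x{out})")
    }

    fn main() {
        let use_qkv_bias = true; // would come from `Config` in the real code
        // Both arms coerce to `fn(usize, usize) -> String`.
        let linear_layer = if use_qkv_bias { with_bias } else { without_bias };
        println!("{}", linear_layer(2560, 2560)); // Linear(2560x2560 + bias)
    }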