Add StableLM-2, StableLM Code and Zephyr variants (#1650)

* Add StableLM Code and Zephyr variants

* Add V2 models

* Update README
Jani Monoses
2024-02-03 15:58:41 +02:00
committed by GitHub
parent dfab45e1c8
commit d32abbce53
3 changed files with 77 additions and 16 deletions


@@ -1,10 +1,11 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm, VarBuilder};
+use serde::Deserialize;
 use std::sync::Arc;
 // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     pub(crate) vocab_size: usize,
     pub(crate) intermediate_size: usize,
@@ -18,7 +19,10 @@ pub struct Config {
     pub(crate) max_position_embeddings: usize,
     pub(crate) norm_eps: f64,
     pub(crate) use_cache: bool,
-    pub(crate) use_flash_attn: bool,
+    #[serde(default)]
+    pub(crate) use_qkv_bias: bool, // Used in StableLM-2
+    #[serde(default)]
+    pub(crate) use_flash_attn: bool, // Not in config.json
 }

 impl Config {
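
Note: deriving `Deserialize` plus `#[serde(default)]` on the two new fields is what keeps older checkpoints loadable. A `config.json` written before StableLM-2 has no `use_qkv_bias` key, and `use_flash_attn` never appears in the file, so both fall back to `false`. A minimal sketch of that behavior, using a cut-down stand-in struct (the real `Config` has many more required fields):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Config {
    vocab_size: usize,
    #[serde(default)] // missing key -> bool::default() == false
    use_qkv_bias: bool,
    #[serde(default)] // never serialized -> false
    use_flash_attn: bool,
}

fn main() -> serde_json::Result<()> {
    // Neither boolean is present in this pre-StableLM-2 style config.
    let cfg: Config = serde_json::from_str(r#"{"vocab_size": 50304}"#)?;
    assert!(!cfg.use_qkv_bias && !cfg.use_flash_attn);
    Ok(())
}
```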
@@ -35,6 +39,7 @@ impl Config {
             rope_theta: 10_000.,
             max_position_embeddings: 4096,
             norm_eps: 1e-5,
+            use_qkv_bias: false,
             use_cache: true,
             use_flash_attn,
         }
@@ -51,6 +56,10 @@ impl Config {
     pub fn num_kv_groups(&self) -> usize {
         self.num_attention_heads / self.num_key_value_heads
     }
+
+    pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
+        self.use_flash_attn = use_flash_attn
+    }
 }

 #[derive(Debug)]
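
Since `use_flash_attn` is never present in `config.json`, the new setter gives callers a way to flip it after deserialization (the fields themselves are `pub(crate)`). A hedged usage sketch, assuming the module path `candle_transformers::models::stable_lm` and an `anyhow`-style caller:

```rust
use candle_transformers::models::stable_lm::Config;

// Hypothetical loader: parse the model's config.json, then enable flash
// attention from a caller-side flag, since the file never carries it.
fn load_config(path: &str, use_flash_attn: bool) -> anyhow::Result<Config> {
    let mut cfg: Config = serde_json::from_str(&std::fs::read_to_string(path)?)?;
    cfg.set_use_flash_attn(use_flash_attn);
    Ok(cfg)
}
```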
@@ -179,9 +188,15 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
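
The `linear_layer` selection works because `linear` and `linear_no_bias` share the signature `(usize, usize, VarBuilder) -> Result<Linear>`, so the two `if`/`else` arms coerce to a common `fn` pointer and the three projection call sites stay unchanged. A self-contained sketch of the same pattern with stand-in functions (the names here are illustrative, not candle API):

```rust
// Two free functions with identical signatures coerce to one fn pointer
// when selected through if/else, mirroring the linear/linear_no_bias pick.
fn with_bias(in_dim: usize, out_dim: usize) -> String {
    format!("Linear({in_dim}x{out_dim}, bias=true)")
}

fn without_bias(in_dim: usize, out_dim: usize) -> String {
    format!("Linear({in_dim}x{out_dim}, bias=false)")
}

fn main() {
    let use_qkv_bias = true; // would come from Config in the real code
    // `layer` has type fn(usize, usize) -> String.
    let layer = if use_qkv_bias { with_bias } else { without_bias };
    println!("{}", layer(2048, 2048));
}
```

Note that only the q/k/v projections switch on the flag; `o_proj` keeps `linear_no_bias` in both cases.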