Add StableLM-2, StableLM Code and Zephyr variants (#1650)

* Add StableLM Code and Zephyr variants

* Add V2 models

* Update README
Jani Monoses
2024-02-03 15:58:41 +02:00
committed by GitHub
parent dfab45e1c8
commit d32abbce53
3 changed files with 77 additions and 16 deletions

View File

@@ -8,6 +8,11 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
 Note that this model is gated so you will have to request access on the Hub in
 order to be able to use it.
+Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.
+StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by Candle, so to run it you can download a somewhat compatible [tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
+and pass it via the --tokenizer-file argument.
 ## Running some example
 ```bash
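The tokenizer note above maps onto the file handling the example already has: if --tokenizer-file is given it is used directly, otherwise tokenizer.json is fetched from the selected Hub repo. Below is a minimal sketch of that resolution, assuming the hf-hub and tokenizers crates the example already depends on; it is a simplified illustration, not the verbatim example code.

```rust
// Minimal sketch, not the exact example code: prefer an explicit
// --tokenizer-file path, otherwise fall back to the repo's tokenizer.json.
use hf_hub::api::sync::Api;
use tokenizers::Tokenizer;

fn load_tokenizer(tokenizer_file: Option<String>, model_id: &str) -> anyhow::Result<Tokenizer> {
    let path = match tokenizer_file {
        // e.g. the Xenova/gpt-4 tokenizer.json suggested above for StableLM-2
        Some(path) => std::path::PathBuf::from(path),
        None => Api::new()?.model(model_id.to_string()).get("tokenizer.json")?,
    };
    Tokenizer::from_file(path).map_err(anyhow::Error::msg)
}
```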

View File

@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
 use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@@ -122,6 +122,16 @@ impl TextGeneration {
     }
 }
+#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
+enum Which {
+    V1Orig,
+    V1,
+    V1Zephyr,
+    V2,
+    V2Zephyr,
+    Code,
+}
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -155,12 +165,15 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
-    #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,
     #[arg(long, default_value = "main")]
     revision: String,
+    #[arg(long, default_value = "v1-orig")]
+    which: Which,
     #[arg(long)]
     tokenizer_file: Option<String>,
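For reference, clap derives the command-line value names of a ValueEnum from the variant names in kebab-case, which is how the "v1-orig" default above selects Which::V1Orig. A small self-contained sketch of just this part of the CLI (clap with the derive feature assumed; only the --which flag is shown):

```rust
// Self-contained sketch of the new --which flag: ValueEnum turns each variant
// into a kebab-case CLI value, so `--which v2-zephyr` parses to Which::V2Zephyr.
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
    V1Orig,
    V1,
    V1Zephyr,
    V2,
    V2Zephyr,
    Code,
}

#[derive(Parser, Debug)]
struct Args {
    /// Which model variant to run (v1-orig, v1, v1-zephyr, v2, v2-zephyr, code).
    #[arg(long, default_value = "v1-orig")]
    which: Which,
}

fn main() {
    let args = Args::parse();
    // With no flag this prints V1Orig, matching the previous single-model behaviour.
    println!("{:?}", args.which);
}
```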
@@ -207,8 +220,20 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id,
+        None => match args.which {
+            Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
+            Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
+            Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
+            Which::Code => "stabilityai/stable-code-3b".to_string(),
+            Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
+            Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
+        },
+    };
     let repo = api.repo(Repo::with_revision(
-        args.model_id,
+        model_id,
         RepoType::Model,
         args.revision,
     ));
@@ -221,19 +246,35 @@ fn main() -> Result<()> {
             .split(',')
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
-        None => {
-            if args.quantized {
-                vec![repo.get("model-q4k.gguf")?]
-            } else {
+        None => match (args.which, args.quantized) {
+            (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
+            (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+                anyhow::bail!("Quantized {:?} variant not supported.", args.which)
+            }
+            (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
                 vec![repo.get("model.safetensors")?]
             }
-        }
+            (Which::Code, false) => {
+                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+            }
+        },
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let start = std::time::Instant::now();
-    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+    let config = match args.which {
+        Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
+        Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
+            let config_filename = repo.get("config.json")?;
+            let config = std::fs::read_to_string(config_filename)?;
+            let mut config: Config = serde_json::from_str(&config)?;
+            config.set_use_flash_attn(args.use_flash_attn);
+            config
+        }
+    };
     let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
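Stable-Code-3B is the only variant above whose weights are fetched through model.safetensors.index.json rather than a single model.safetensors, hence the extra (Which::Code, false) arm calling hub_load_safetensors. Below is a standalone sketch of the work such a helper has to do with the index file, using a trimmed stand-in for the real index rather than the candle-examples implementation itself.

```rust
// Standalone sketch of the work behind hub_load_safetensors for a sharded
// checkpoint: read the index, collect the distinct shard names from
// "weight_map", then fetch each one. The JSON below is a trimmed stand-in.
use std::collections::HashSet;

fn main() -> serde_json::Result<()> {
    let index = r#"{
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "lm_head.weight": "model-00002-of-00002.safetensors"
        }
    }"#;
    let json: serde_json::Value = serde_json::from_str(index)?;
    let shards: HashSet<&str> = json["weight_map"]
        .as_object()
        .map(|map| map.values().filter_map(|v| v.as_str()).collect())
        .unwrap_or_default();
    // In the example, each of these shard names would then be passed to repo.get(..).
    println!("{shards:?}");
    Ok(())
}
```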

View File

@@ -1,10 +1,11 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm, VarBuilder};
+use serde::Deserialize;
 use std::sync::Arc;
 // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     pub(crate) vocab_size: usize,
     pub(crate) intermediate_size: usize,
@@ -18,7 +19,10 @@ pub struct Config {
     pub(crate) max_position_embeddings: usize,
     pub(crate) norm_eps: f64,
     pub(crate) use_cache: bool,
-    pub(crate) use_flash_attn: bool,
+    #[serde(default)]
+    pub(crate) use_qkv_bias: bool, // Used in StableLM-2
+    #[serde(default)]
+    pub(crate) use_flash_attn: bool, // Not in config.json
 }
 impl Config {
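The two #[serde(default)] attributes above are what let the non-original variants deserialize Config straight from the repo's config.json in main.rs: use_qkv_bias may be absent from a given config and use_flash_attn is never stored in the file, so both simply fall back to false. A standalone sketch of that mechanism with a stand-in struct (not the real candle Config):

```rust
// Standalone sketch (stand-in struct, not the real candle Config): fields
// marked #[serde(default)] fall back to false when config.json does not
// contain them, and the caller then overrides use_flash_attn from the CLI.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Config {
    vocab_size: usize,
    norm_eps: f64,
    #[serde(default)]
    use_qkv_bias: bool, // defaults to false when absent; used by StableLM-2
    #[serde(default)]
    use_flash_attn: bool, // never stored in config.json
}

fn main() -> serde_json::Result<()> {
    // A trimmed stand-in config.json: both defaulted fields are absent and become false.
    let json = r#"{ "vocab_size": 50304, "norm_eps": 1e-5 }"#;
    let mut config: Config = serde_json::from_str(json)?;
    config.use_flash_attn = true; // the role set_use_flash_attn plays in the real code
    println!("{config:?}");
    Ok(())
}
```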
@@ -35,6 +39,7 @@ impl Config {
             rope_theta: 10_000.,
             max_position_embeddings: 4096,
             norm_eps: 1e-5,
+            use_qkv_bias: false,
             use_cache: true,
             use_flash_attn,
         }
@@ -51,6 +56,10 @@ impl Config {
     pub fn num_kv_groups(&self) -> usize {
         self.num_attention_heads / self.num_key_value_heads
     }
+    pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
+        self.use_flash_attn = use_flash_attn
+    }
 }
 #[derive(Debug)]
@@ -179,9 +188,15 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
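The linear_layer binding above leans on a small Rust idiom: two fn items with the same signature coerce to a common function pointer inside an if/else, so the bias and no-bias constructors can then be called through one name. A dependency-free toy sketch of the pattern (stand-in functions, not the real candle_nn layers):

```rust
// Dependency-free sketch of the pattern: both arms are fn items with the same
// signature, so the if/else coerces them to a single function pointer.
fn with_bias(dim: usize) -> String {
    format!("Linear({dim}x{dim}, bias=true)")
}
fn without_bias(dim: usize) -> String {
    format!("Linear({dim}x{dim}, bias=false)")
}

fn main() {
    let use_qkv_bias = true; // the case the new code path exists for (StableLM-2)
    let linear_layer = if use_qkv_bias { with_bias } else { without_bias };
    println!("{}", linear_layer(2048)); // Linear(2048x2048, bias=true)
}
```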