Add StableLM-2, StableLM Code and Zephyr variants (#1650)
* Add StableLM Code and Zephyr variants
* Add V2 models
* Update README
candle-examples/examples/stable-lm/README.md

@@ -8,6 +8,11 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
 Note that this model is gated so you will have to request access on the Hub in
 order to be able to use it.
 
+Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.
+
+StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by Candle, so to run it you can download a somewhat compatible [tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
+and pass it via the --tokenizer-file argument.
+
 ## Running some example
 
 ```bash
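For reference, a StableLM-2 run with the downloaded tokenizer could look like the sketch below. It assumes the example binary is named `stable-lm` and accepts a `--prompt` argument, as in the rest of this example's code; neither appears in this diff.

```bash
# Hypothetical invocation: `--which v2` selects the StableLM-2 model added in
# this commit, and `--tokenizer-file` points at the tokenizer.json downloaded
# from the Xenova/gpt-4 repo.
cargo run --example stable-lm --release -- \
    --which v2 \
    --tokenizer-file tokenizer.json \
    --prompt "What is the most efficient programming language in use?"
```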
candle-examples/examples/stable-lm/main.rs

@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
 use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@@ -122,6 +122,16 @@ impl TextGeneration {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
+enum Which {
+    V1Orig,
+    V1,
+    V1Zephyr,
+    V2,
+    V2Zephyr,
+    Code,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
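A note on the new `Which` enum: clap's `ValueEnum` derive matches command-line values against the kebab-cased variant names, which is why `V1Orig` is written `v1-orig` in the default value further down. A minimal, self-contained sketch of that behaviour (assuming clap 4 with the `derive` feature enabled; this snippet is illustrative, not part of the commit):

```rust
use clap::{Parser, ValueEnum};

// Variants become the CLI values `v1-orig` and `v2-zephyr`.
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
    V1Orig,
    V2Zephyr,
}

#[derive(Parser, Debug)]
struct Args {
    #[arg(long, default_value = "v1-orig")]
    which: Which,
}

fn main() {
    // `./app --which v2-zephyr` prints `V2Zephyr`; an unknown value is
    // rejected by clap with a list of the accepted ones.
    let args = Args::parse();
    println!("{:?}", args.which);
}
```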
@@ -155,12 +165,15 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
-    #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,
 
     #[arg(long, default_value = "main")]
     revision: String,
 
+    #[arg(long, default_value = "v1-orig")]
+    which: Which,
+
     #[arg(long)]
     tokenizer_file: Option<String>,
 
@@ -207,8 +220,20 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id,
+        None => match args.which {
+            Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
+            Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
+            Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
+            Which::Code => "stabilityai/stable-code-3b".to_string(),
+            Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
+            Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
+        },
+    };
+
     let repo = api.repo(Repo::with_revision(
-        args.model_id,
+        model_id,
         RepoType::Model,
         args.revision,
     ));
@@ -221,19 +246,35 @@ fn main() -> Result<()> {
             .split(',')
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
-        None => {
-            if args.quantized {
-                vec![repo.get("model-q4k.gguf")?]
-            } else {
+        None => match (args.which, args.quantized) {
+            (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
+            (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+                anyhow::bail!("Quantized {:?} variant not supported.", args.which)
+            }
+            (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
                 vec![repo.get("model.safetensors")?]
             }
-        }
+            (Which::Code, false) => {
+                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+            }
+        },
     };
 
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+    let config = match args.which {
+        Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
+        Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
+            let config_filename = repo.get("config.json")?;
+            let config = std::fs::read_to_string(config_filename)?;
+            let mut config: Config = serde_json::from_str(&config)?;
+            config.set_use_flash_attn(args.use_flash_attn);
+            config
+        }
+    };
 
     let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
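The `(args.which, args.quantized)` tuple match replaces the old boolean `if`: the compiler now checks that every variant/quantization combination is handled, and unsupported ones fail loudly through `anyhow::bail!`. Below is a dependency-free sketch of the same pattern; the enum is trimmed and the shard file names are placeholders, not candle's actual layout.

```rust
#[derive(Debug, Clone, Copy)]
enum Which {
    V1Orig,
    V1,
    Code,
}

// Select weight files from both the model variant and the quantization flag;
// the match must be exhaustive over every (Which, bool) combination.
fn weight_files(which: Which, quantized: bool) -> Result<Vec<&'static str>, String> {
    match (which, quantized) {
        (Which::V1Orig, true) => Ok(vec!["model-q4k.gguf"]),
        (Which::V1 | Which::Code, true) => {
            Err(format!("Quantized {which:?} variant not supported."))
        }
        (Which::V1Orig | Which::V1, false) => Ok(vec!["model.safetensors"]),
        // Sharded checkpoints list their parts in an index file; placeholder
        // shard names stand in for what the index would resolve to.
        (Which::Code, false) => Ok(vec![
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ]),
    }
}

fn main() {
    println!("{:?}", weight_files(Which::Code, false));
    println!("{:?}", weight_files(Which::V1, true));
}
```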
candle-transformers/src/models/stable_lm.rs

@@ -1,10 +1,11 @@
-use crate::models::with_tracing::{linear_no_bias, Linear};
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm, VarBuilder};
+use serde::Deserialize;
 use std::sync::Arc;
 
 // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     pub(crate) vocab_size: usize,
     pub(crate) intermediate_size: usize,
@@ -18,7 +19,10 @@ pub struct Config {
     pub(crate) max_position_embeddings: usize,
     pub(crate) norm_eps: f64,
     pub(crate) use_cache: bool,
-    pub(crate) use_flash_attn: bool,
+    #[serde(default)]
+    pub(crate) use_qkv_bias: bool, // Used in StableLM-2
+    #[serde(default)]
+    pub(crate) use_flash_attn: bool, // Not in config.json
 }
 
 impl Config {
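The two `#[serde(default)]` attributes are what keep older checkpoints loadable: v1-era `config.json` files have no `use_qkv_bias` key, and `use_flash_attn` is never stored in the config at all, so both fall back to `false` when absent. A small sketch of that behaviour, assuming the `serde` and `serde_json` crates and a struct trimmed to the relevant fields:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct MiniConfig {
    norm_eps: f64,
    #[serde(default)] // absent from v1 configs -> false
    use_qkv_bias: bool,
    #[serde(default)] // never present in config.json -> false
    use_flash_attn: bool,
}

fn main() -> serde_json::Result<()> {
    // A v1-style config that predates `use_qkv_bias` still deserializes.
    let cfg: MiniConfig = serde_json::from_str(r#"{ "norm_eps": 1e-5 }"#)?;
    assert!(!cfg.use_qkv_bias && !cfg.use_flash_attn);
    println!("{cfg:?}");
    Ok(())
}
```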
@@ -35,6 +39,7 @@ impl Config {
             rope_theta: 10_000.,
             max_position_embeddings: 4096,
             norm_eps: 1e-5,
+            use_qkv_bias: false,
             use_cache: true,
             use_flash_attn,
         }
@@ -51,6 +56,10 @@ impl Config {
     pub fn num_kv_groups(&self) -> usize {
         self.num_attention_heads / self.num_key_value_heads
     }
+
+    pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
+        self.use_flash_attn = use_flash_attn
+    }
 }
 
 #[derive(Debug)]
@@ -179,9 +188,15 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
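The `linear_layer` binding above works because each branch of the `if` names a plain function item and both coerce to the same `fn` pointer type, so the chosen constructor can be stored and called later like either original function. A standalone sketch of the pattern with toy signatures (not candle's actual `linear`/`linear_no_bias`):

```rust
fn with_bias(x: f32) -> f32 {
    2.0 * x + 1.0
}

fn without_bias(x: f32) -> f32 {
    2.0 * x
}

fn main() {
    let use_bias = true;
    // Both arms coerce to `fn(f32) -> f32`, mirroring how `linear_layer`
    // becomes either `linear` or `linear_no_bias` depending on the config.
    let layer: fn(f32) -> f32 = if use_bias { with_bias } else { without_bias };
    println!("{}", layer(3.0)); // prints 7
}
```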