Add StableLM-2, StableLM Code and Zephyr variants (#1650)

* Add StableLM Code and Zephyr variants

* Add V2 models

* Update README
This commit is contained in:
Jani Monoses
2024-02-03 15:58:41 +02:00
committed by GitHub
parent dfab45e1c8
commit d32abbce53
3 changed files with 77 additions and 16 deletions

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use clap::{Parser, ValueEnum};
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@ -122,6 +122,16 @@ impl TextGeneration {
}
}
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
V1Orig,
V1,
V1Zephyr,
V2,
V2Zephyr,
Code,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -155,12 +165,15 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
model_id: String,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "v1-orig")]
which: Which,
#[arg(long)]
tokenizer_file: Option<String>,
@ -207,8 +220,20 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.which {
Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
Which::Code => "stabilityai/stable-code-3b".to_string(),
Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
args.model_id,
model_id,
RepoType::Model,
args.revision,
));
@ -221,19 +246,35 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
vec![repo.get("model-q4k.gguf")?]
} else {
None => match (args.which, args.quantized) {
(Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
(Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
anyhow::bail!("Quantized {:?} variant not supported.", args.which)
}
(Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
vec![repo.get("model.safetensors")?]
}
}
(Which::Code, false) => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
let config = match args.which {
Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let mut config: Config = serde_json::from_str(&config)?;
config.set_use_flash_attn(args.use_flash_attn);
config
}
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];