mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add StableLM-2, StableLM Code and Zephyr variants (#1650)
* Add StableLM Code and Zephyr variants * Add V2 models * Update README
This commit is contained in:
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
|
||||
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
|
||||
@ -122,6 +122,16 @@ impl TextGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
||||
enum Which {
|
||||
V1Orig,
|
||||
V1,
|
||||
V1Zephyr,
|
||||
V2,
|
||||
V2Zephyr,
|
||||
Code,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -155,12 +165,15 @@ struct Args {
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
|
||||
model_id: String,
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long, default_value = "v1-orig")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
@ -207,8 +220,20 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => match args.which {
|
||||
Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
|
||||
Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
|
||||
Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
|
||||
Which::Code => "stabilityai/stable-code-3b".to_string(),
|
||||
Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
|
||||
Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
@ -221,19 +246,35 @@ fn main() -> Result<()> {
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
vec![repo.get("model-q4k.gguf")?]
|
||||
} else {
|
||||
None => match (args.which, args.quantized) {
|
||||
(Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
|
||||
(Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
|
||||
anyhow::bail!("Quantized {:?} variant not supported.", args.which)
|
||||
}
|
||||
(Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
}
|
||||
(Which::Code, false) => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
|
||||
let config = match args.which {
|
||||
Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
|
||||
Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let mut config: Config = serde_json::from_str(&config)?;
|
||||
config.set_use_flash_attn(args.use_flash_attn);
|
||||
config
|
||||
}
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (model, device) = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
|
Reference in New Issue
Block a user