Quantized support for stable-lm2. (#1654)

* Quantized support for stable-lm2.

* Quantized support for v2-zephyr.
This commit is contained in:
Laurent Mazare
2024-02-04 11:57:05 +01:00
committed by GitHub
parent 58cc896e69
commit 50be8a98ba
3 changed files with 36 additions and 10 deletions

View File

@ -10,7 +10,9 @@ order to be able to use it.
Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.
StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by Candle, so to run it you can download a somewhat compatible [tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
Candle, so to run it you can download a somewhat compatible
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the --tokenizer-file argument.
## Running some example

View File

@ -162,7 +162,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
#[arg(long, short = 'n', default_value_t = 1000)]
sample_len: usize,
#[arg(long)]
@ -171,7 +171,7 @@ struct Args {
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "v1-orig")]
#[arg(long, default_value = "v2")]
which: Which,
#[arg(long)]
@ -239,7 +239,14 @@ fn main() -> Result<()> {
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
None => match args.which {
Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
repo.get("tokenizer.json")?
}
Which::V2 | Which::V2Zephyr => api
.model("lmz/candle-stablelm".to_string())
.get("tokenizer-gpt4.json")?,
},
};
let filenames = match args.weight_files {
Some(files) => files
@ -247,8 +254,20 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match (args.which, args.quantized) {
(Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
(Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
(Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
(Which::V2, true) => {
let gguf = api
.model("lmz/candle-stablelm".to_string())
.get("stablelm-2-1_6b-q4k.gguf")?;
vec![gguf]
}
(Which::V2Zephyr, true) => {
let gguf = api
.model("lmz/candle-stablelm".to_string())
.get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
vec![gguf]
}
(Which::V1Zephyr | Which::Code, true) => {
anyhow::bail!("Quantized {:?} variant not supported.", args.which)
}
(Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {