From 50be8a98ba08295ec3ff46d0a779937bc06d369e Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 4 Feb 2024 11:57:05 +0100
Subject: [PATCH] Quantized support for stable-lm2. (#1654)

* Quantized support for stable-lm2.

* Quantized support for v2-zephyr.
---
 candle-examples/examples/stable-lm/README.md |  4 +++-
 candle-examples/examples/stable-lm/main.rs   | 29 ++++++++++++++++++++++++-----
 .../src/models/quantized_stable_lm.rs        | 13 +++++++++----
 3 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/candle-examples/examples/stable-lm/README.md b/candle-examples/examples/stable-lm/README.md
index 485812d3..546124a2 100644
--- a/candle-examples/examples/stable-lm/README.md
+++ b/candle-examples/examples/stable-lm/README.md
@@ -10,7 +10,9 @@ order to be able to use it.
 
 Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.
 
-StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by Candle, so to run it you can download a somewhat compatible [tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
+StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
+Candle, so to run it you can download a somewhat compatible
+[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
 and pass it via the --tokenizer-file argument.
 
 ## Running some example
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 415c6e7e..abe7020c 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -162,7 +162,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 1000)]
     sample_len: usize,
 
     #[arg(long)]
@@ -171,7 +171,7 @@ struct Args {
     #[arg(long, default_value = "main")]
     revision: String,
 
-    #[arg(long, default_value = "v1-orig")]
+    #[arg(long, default_value = "v2")]
     which: Which,
 
     #[arg(long)]
@@ -239,7 +239,14 @@ fn main() -> Result<()> {
     ));
     let tokenizer_filename = match args.tokenizer_file {
         Some(file) => std::path::PathBuf::from(file),
-        None => repo.get("tokenizer.json")?,
+        None => match args.which {
+            Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
+                repo.get("tokenizer.json")?
+            }
+            Which::V2 | Which::V2Zephyr => api
+                .model("lmz/candle-stablelm".to_string())
+                .get("tokenizer-gpt4.json")?,
+        },
     };
     let filenames = match args.weight_files {
         Some(files) => files
@@ -247,8 +254,20 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => match (args.which, args.quantized) {
-            (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
-            (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+            (Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
+            (Which::V2, true) => {
+                let gguf = api
+                    .model("lmz/candle-stablelm".to_string())
+                    .get("stablelm-2-1_6b-q4k.gguf")?;
+                vec![gguf]
+            }
+            (Which::V2Zephyr, true) => {
+                let gguf = api
+                    .model("lmz/candle-stablelm".to_string())
+                    .get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
+                vec![gguf]
+            }
+            (Which::V1Zephyr | Which::Code, true) => {
                 anyhow::bail!("Quantized {:?} variant not supported.", args.which)
             }
             (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index 94c96201..fa301f6c 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -1,4 +1,4 @@
-use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
+use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
@@ -67,9 +67,14 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
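
The README hunk above works around StableLM-2's Tiktoken vocabulary by pointing at a roughly compatible `tokenizer.json`. Once that file is on disk, it loads through the `tokenizers` crate like any other tokenizer. A minimal sketch, assuming the `tokenizers` and `anyhow` crates as dependencies; the file name and prompt string are illustrative only:

```rust
use tokenizers::Tokenizer;

fn main() -> anyhow::Result<()> {
    // Load the GPT-4-style tokenizer the README links to (or any path passed
    // via --tokenizer-file); `from_file` parses a standard tokenizer.json.
    let tokenizer = Tokenizer::from_file("tokenizer-gpt4.json").map_err(anyhow::Error::msg)?;
    let encoding = tokenizer
        .encode("What is the most efficient programming language?", true)
        .map_err(anyhow::Error::msg)?;
    println!("prompt -> {} tokens", encoding.get_ids().len());
    Ok(())
}
```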
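In the main.rs hunk, both the V2 tokenizer and the V2/V2-Zephyr GGUF weights come from the `lmz/candle-stablelm` side repo rather than from the checkpoint repo itself. Below is a minimal sketch of that fetch pattern, assuming the `hf-hub` crate's blocking API (the same `Api` the example uses); the printed paths are local cache locations:

```rust
use hf_hub::api::sync::Api;

fn main() -> anyhow::Result<()> {
    // Hub client using the default cache directory.
    let api = Api::new()?;
    // Files the checkpoint repo lacks are pulled from the side repo,
    // mirroring the Which::V2 / Which::V2Zephyr arms in the patch.
    let repo = api.model("lmz/candle-stablelm".to_string());
    let tokenizer = repo.get("tokenizer-gpt4.json")?;
    let weights = repo.get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
    println!("tokenizer: {}", tokenizer.display());
    println!("weights:   {}", weights.display());
    Ok(())
}
```

`get` downloads the file on first use and returns the cached path on later runs, which is why the example can call it unconditionally.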
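The quantized_stable_lm.rs hunk selects between `linear` and `linear_no_bias` once, binding the chosen constructor to `linear_layer`, instead of repeating the `cfg.use_qkv_bias` check for each of the three projections. This works because the two functions share a signature, so both arms of the `if` coerce to the same function-pointer type. A self-contained sketch of the pattern; `Proj` and the dimensions are hypothetical stand-ins, not candle types:

```rust
// Hypothetical stand-in for a quantized linear layer.
#[derive(Debug)]
struct Proj {
    in_dim: usize,
    out_dim: usize,
    bias: bool,
}

fn linear(in_dim: usize, out_dim: usize) -> Proj {
    Proj { in_dim, out_dim, bias: true }
}

fn linear_no_bias(in_dim: usize, out_dim: usize) -> Proj {
    Proj { in_dim, out_dim, bias: false }
}

fn main() {
    let use_qkv_bias = true; // stand-in for cfg.use_qkv_bias
    // Each function item has its own zero-sized type, but both coerce to the
    // common pointer type `fn(usize, usize) -> Proj`, so the branch picks a
    // constructor once and the call sites below stay identical.
    let linear_layer: fn(usize, usize) -> Proj =
        if use_qkv_bias { linear } else { linear_no_bias };
    let q_proj = linear_layer(2048, 2048);
    let k_proj = linear_layer(2048, 512);
    println!("{q_proj:?}\n{k_proj:?}");
}
```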