Quantized support for stable-lm2. (#1654)
* Quantized support for stable-lm2.
* Quantized support for v2-zephyr.
In the example README, the note about the StableLM-2 tokenizer is re-wrapped:

```diff
@@ -10,7 +10,9 @@ order to be able to use it.
 
 Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.
 
-StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by Candle, so to run it you can download a somewhat compatible [tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
+StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
+Candle, so to run it you can download a somewhat compatible
+[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
 and pass it via the --tokenizer-file argument.
 
 ## Running some example
```
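As a side note, a minimal sketch of what consuming that downloaded file looks like with the `tokenizers` crate, in the same spirit as the example code (not part of this diff; the file path and prompt are placeholders):

```rust
// Hypothetical sketch: load the downloaded tokenizer.json and encode a prompt.
use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

fn main() -> Result<()> {
    // Path to the file fetched from https://huggingface.co/Xenova/gpt-4 (see README above).
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(E::msg)?;
    let encoding = tokenizer
        .encode("The most important thing is", true)
        .map_err(E::msg)?;
    println!("{} tokens: {:?}", encoding.len(), encoding.get_ids());
    Ok(())
}
```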
In the example's `main.rs`, the default sample length and the default model variant change:

```diff
@@ -162,7 +162,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 1000)]
     sample_len: usize,
 
     #[arg(long)]
@@ -171,7 +171,7 @@ struct Args {
     #[arg(long, default_value = "main")]
     revision: String,
 
-    #[arg(long, default_value = "v1-orig")]
+    #[arg(long, default_value = "v2")]
     which: Which,
 
     #[arg(long)]
```
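The `Which` flag selects the model variant. A hedged sketch of how such a clap `ValueEnum` could be declared, using only the variant names that appear in this commit's match arms (the derive details are assumptions, not copied from the source): with clap's default kebab-case renaming, `V1Orig` maps to `"v1-orig"` and `V2` to `"v2"`, which is what the `default_value` strings above refer to.

```rust
use clap::{Parser, ValueEnum};

// Variants as used in this commit; value names follow clap's kebab-case default.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    V1Orig,   // "v1-orig"
    V1,       // "v1"
    V1Zephyr, // "v1-zephyr"
    V2,       // "v2"
    V2Zephyr, // "v2-zephyr"
    Code,     // "code"
}

#[derive(Parser, Debug)]
struct Args {
    /// The model variant to run.
    #[arg(long, default_value = "v2")]
    which: Which,
}

fn main() {
    let args = Args::parse();
    println!("{:?}", args.which);
}
```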
Still in `main.rs`, the tokenizer file now depends on the variant; the V2 variants fetch a GPT-4 style tokenizer from a separate Hub repo:

```diff
@@ -239,7 +239,14 @@ fn main() -> Result<()> {
     ));
     let tokenizer_filename = match args.tokenizer_file {
         Some(file) => std::path::PathBuf::from(file),
-        None => repo.get("tokenizer.json")?,
+        None => match args.which {
+            Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
+                repo.get("tokenizer.json")?
+            }
+            Which::V2 | Which::V2Zephyr => api
+                .model("lmz/candle-stablelm".to_string())
+                .get("tokenizer-gpt4.json")?,
+        },
     };
     let filenames = match args.weight_files {
         Some(files) => files
```
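For reference, a standalone sketch of that `hf-hub` fetch, with the repo and file names taken from the diff and error handling simplified (assuming the `hf-hub` sync API, which the example uses):

```rust
use hf_hub::api::sync::Api;

fn main() -> anyhow::Result<()> {
    // Download tokenizer-gpt4.json from the lmz/candle-stablelm repo into the local cache.
    let api = Api::new()?;
    let repo = api.model("lmz/candle-stablelm".to_string());
    let tokenizer_path = repo.get("tokenizer-gpt4.json")?;
    println!("tokenizer at {}", tokenizer_path.display());
    Ok(())
}
```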
The quantized weight files for the V2 and V2-Zephyr variants are likewise fetched as gguf files from that repo:

```diff
@@ -247,8 +254,20 @@ fn main() -> Result<()> {
                 .map(std::path::PathBuf::from)
                 .collect::<Vec<_>>(),
             None => match (args.which, args.quantized) {
-                (Which::V1Orig, true) => vec![repo.get("model-q4k.gguf")?],
-                (Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code, true) => {
+                (Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
+                (Which::V2, true) => {
+                    let gguf = api
+                        .model("lmz/candle-stablelm".to_string())
+                        .get("stablelm-2-1_6b-q4k.gguf")?;
+                    vec![gguf]
+                }
+                (Which::V2Zephyr, true) => {
+                    let gguf = api
+                        .model("lmz/candle-stablelm".to_string())
+                        .get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
+                    vec![gguf]
+                }
+                (Which::V1Zephyr | Which::Code, true) => {
                     anyhow::bail!("Quantized {:?} variant not supported.", args.which)
                 }
                 (Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
```
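Once such a gguf file is on disk, its contents can be inspected with candle's gguf reader before building the model. A hedged sketch, assuming `candle-core` is imported as `candle` as in the candle workspace; the file name is a placeholder:

```rust
use candle::quantized::gguf_file;

fn main() -> anyhow::Result<()> {
    // Open a previously downloaded gguf file and list what it contains.
    let path = "stablelm-2-zephyr-1_6b-q4k.gguf";
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    println!(
        "{} tensors, {} metadata entries",
        content.tensor_infos.len(),
        content.metadata.len()
    );
    Ok(())
}
```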
In the quantized model implementation, the q/k/v projections now respect `use_qkv_bias` by choosing between the biased and bias-free linear constructors:

```diff
@@ -1,4 +1,4 @@
-use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
+use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
@@ -67,9 +67,14 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let linear_layer = if cfg.use_qkv_bias {
+            linear
+        } else {
+            linear_no_bias
+        };
+        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
```
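This works because `linear` and `linear_no_bias` share a signature, so a plain `fn` item can be selected from the config flag and then called like either one. A small self-contained illustration of that idiom (the types below are stand-ins, not candle's):

```rust
#[derive(Debug)]
struct Layer {
    in_dim: usize,
    out_dim: usize,
    has_bias: bool,
}

// Two constructors with identical signatures, mirroring linear / linear_no_bias.
fn with_bias(in_dim: usize, out_dim: usize) -> Layer {
    Layer { in_dim, out_dim, has_bias: true }
}

fn without_bias(in_dim: usize, out_dim: usize) -> Layer {
    Layer { in_dim, out_dim, has_bias: false }
}

fn main() {
    let use_qkv_bias = true; // would come from the model config (cfg.use_qkv_bias)
    // The if/else coerces both fn items to the same fn pointer type.
    let linear_layer = if use_qkv_bias { with_bias } else { without_bias };
    let q_proj = linear_layer(2048, 2048);
    println!("{q_proj:?}");
}
```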