add quantized qwen2 (#2329)

* add quantized version of qwen2 and corresponding example for qwen2-instruct

* fix quantized qwen2 clippy error
Zhuo Jinggang
2024-07-12 16:00:03 +08:00
committed by GitHub
parent a226a9736b
commit c63048d374
4 changed files with 641 additions and 0 deletions


@@ -0,0 +1,11 @@
# candle-quantized-qwen2-instruct
[Qwen2](https://qwenlm.github.io/blog/qwen2/) is an upgraded version of Qwen1.5, released by Alibaba Cloud.
## Running the example
```bash
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
```
0.5b, 1.5b, 7b and 72b models are available via the `--which` argument.
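For example, to run the 1.5b model with a custom prompt:

```bash
cargo run --example quantized-qwen2-instruct --release -- --which 1.5b --prompt "Explain Rust's ownership model in two sentences."
```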


@@ -0,0 +1,306 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_qwen2::ModelWeights as Qwen2;
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "0.5b")]
W2_0_5b,
#[value(name = "1.5b")]
W2_1_5b,
#[value(name = "7b")]
W2_7b,
#[value(name = "72b")]
W2_72b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
#[arg(long)]
model: Option<String>,
/// The initial prompt to use.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "0.5b")]
which: Which,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = match self.which {
Which::W2_0_5b => "Qwen/Qwen2-0.5B-Instruct",
Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename, revision) = match self.which {
Which::W2_0_5b => (
"Qwen/Qwen2-0.5B-Instruct-GGUF",
"qwen2-0_5b-instruct-q4_0.gguf",
"main",
),
Which::W2_1_5b => (
"Qwen/Qwen2-1.5B-Instruct-GGUF",
"qwen2-1_5b-instruct-q4_0.gguf",
"main",
),
Which::W2_7b => (
"Qwen/Qwen2-7B-Instruct-GGUF",
"qwen2-7b-instruct-q4_0.gguf",
"main",
),
Which::W2_72b => (
"Qwen/Qwen2-72B-Instruct-GGUF",
"qwen2-72b-instruct-q4_0.gguf",
"main",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
revision.to_string(),
))
.get(filename)?
}
};
Ok(model_path)
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
Qwen2::from_gguf(model, &mut file, &device)?
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
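// Wrap the raw prompt in Qwen2's ChatML-style instruct template: a user turn followed by an opened assistant turn for the model to complete.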
let prompt_str = format!(
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
prompt_str
);
print!("formatted instruct prompt: {}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let tokens = tokens.get_ids();
let to_sample = args.sample_len.saturating_sub(1);
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
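// Prompt processing: by default the full prompt is run through the model in a single forward pass at position 0; with --split_prompt each token is fed separately at its own position.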
let mut next_token = if !args.split_prompt {
let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
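// <|im_end|> closes a turn in the Qwen2 chat template, so treat it as the end-of-sequence token and stop once it is sampled.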
let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
tokens.len(),
tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
Ok(())
}


@@ -47,6 +47,7 @@ pub mod quantized_moondream;
pub mod quantized_mpt;
pub mod quantized_phi;
pub mod quantized_phi3;
pub mod quantized_qwen2;
pub mod quantized_recurrent_gemma;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;


@@ -0,0 +1,323 @@
use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
use candle::{
quantized::{gguf_file, QMatMul},
DType, Device, IndexOp, Result, Tensor,
};
use candle_nn::{Embedding, Module};
use std::collections::HashMap;
#[derive(Debug, Clone)]
struct Mlp {
feed_forward_w1: QMatMul,
feed_forward_w2: QMatMul,
feed_forward_w3: QMatMul,
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
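// SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)), where w1/w3/w2 are the gate, up and down projections loaded from the GGUF file.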
let w1 = self.feed_forward_w1.forward(xs)?;
let w3 = self.feed_forward_w3.forward(xs)?;
self.feed_forward_w2
.forward(&(candle_nn::ops::silu(&w1)? * w3)?)
}
}
#[derive(Debug, Clone)]
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_bq: Tensor,
attention_bk: Tensor,
attention_bv: Tensor,
attention_wo: QMatMul,
attention_norm: RmsNorm,
mlp: Mlp,
ffn_norm: RmsNorm,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
cos: Tensor,
sin: Tensor,
neg_inf: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
span_attn: tracing::Span,
span_rot: tracing::Span,
span_mlp: tracing::Span,
}
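// Where `mask` is non-zero, replace the attention score with `on_true` (negative infinity here) so those positions contribute nothing after the softmax.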
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
let shape = mask.shape();
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
Ok(m)
}
impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
candle_nn::rotary_emb::rope(&x.contiguous()?, &cos, &sin)
}
fn forward_attn(
&mut self,
x: &Tensor,
mask: Option<&Tensor>,
index_pos: usize,
) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;
let q = q.broadcast_add(&self.attention_bq)?;
let k = k.broadcast_add(&self.attention_bk)?;
let v = v.broadcast_add(&self.attention_bv)?;
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
// let (q, k) = self
// .rotary_embedding
// .apply_rotary_emb_qkv(&q, &k, index_pos)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let k = self.apply_rotary_emb(&k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?;
let v = Tensor::cat(&[v_cache, &v], 2)?;
(k, v)
}
}
};
self.kv_cache = Some((k.clone(), v.clone()));
// Support for MQA, useful for 70B models and mistral.
let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
None => att,
Some(mask) => {
let mask = mask.broadcast_as(att.shape())?;
masked_fill(&att, &mask, &self.neg_inf)?
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attention_wo.forward(&y)?;
Ok(y)
}
}
pub struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,
norm: RmsNorm,
output: QMatMul,
masks: HashMap<usize, Tensor>,
span: tracing::Span,
span_output: tracing::Span,
}
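// Precompute the RoPE cos/sin tables: theta_i = 1 / freq_base^(2i/head_dim), one row per position up to the model's context length.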
fn precomput_freqs_cis(
head_dim: usize,
freq_base: f32,
context_length: usize,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, context_length as u32, device)?
.to_dtype(DType::F32)?
.reshape((context_length, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
let head_count = md_get("qwen2.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("qwen2.attention.head_count_kv")?.to_u32()? as usize;
let embedding_length = md_get("qwen2.embedding_length")?.to_u32()? as usize;
let context_length = md_get("qwen2.context_length")?.to_u32()? as usize;
let block_count = md_get("qwen2.block_count")?.to_u32()? as usize;
let rms_norm_eps = md_get("qwen2.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let rope_freq_base = md_get("qwen2.rope.freq_base")
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let head_dim = embedding_length / head_count;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::from_qtensor(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
)?;
let output = match ct.tensor(reader, "output.weight", device) {
Ok(v) => QMatMul::from_qtensor(v)?,
_ => {
// use tie_word_embeddings
QMatMul::from_qtensor(ct.tensor(reader, "token_embd.weight", device)?)?
}
};
let (cos, sin) = precomput_freqs_cis(head_dim, rope_freq_base, context_length, device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_bq = ct.tensor(reader, &format!("{prefix}.attn_q.bias"), device)?;
let attention_bk = ct.tensor(reader, &format!("{prefix}.attn_k.bias"), device)?;
let attention_bv = ct.tensor(reader, &format!("{prefix}.attn_v.bias"), device)?;
let attention_wo =
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let mlp = {
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let feed_forward_w3 =
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
}
};
let attention_norm =
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
layers.push(LayerWeights {
attention_wq: QMatMul::from_qtensor(attention_wq)?,
attention_wk: QMatMul::from_qtensor(attention_wk)?,
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_bq: attention_bq.dequantize(device)?,
attention_bk: attention_bk.dequantize(device)?,
attention_bv: attention_bv.dequantize(device)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
cos: cos.clone(),
sin: sin.clone(),
mlp,
ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim,
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_rot,
span_mlp,
});
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
layers,
norm,
output,
masks: HashMap::new(),
span,
span_output,
})
}
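// Build and cache a t x t causal mask: ones above the diagonal mark future positions, which forward_attn fills with -inf before the softmax.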
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), device)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mask = if seq_len == 1 {
None
} else {
Some(self.mask(seq_len, x.device())?)
};
let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() {
let x = layer_in;
let residual = &x;
let x = layer.attention_norm.forward(&x)?;
let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
let x = (attn + residual)?;
// MLP
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let x = layer.mlp.forward(&x)?;
let x = (x + residual)?;
layer_in = x
}
let x = self.norm.forward(&layer_in)?;
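// Only the hidden state of the last position is needed to predict the next token.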
let x = x.i((.., seq_len - 1, ..))?;
let _enter = self.span_output.enter();
self.output.forward(&x)
}
}