mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a quantized test that use negative values. (#470)
* Add a quantized test that use negative values. * Add a default tokenizer.
This commit is contained in:
@ -1,5 +1,7 @@
|
|||||||
use candle_core::{quantized, Device, Result, Tensor};
|
use candle_core::{quantized, Device, Result, Tensor};
|
||||||
use quantized::{k_quants, GgmlType};
|
use quantized::{k_quants, GgmlType};
|
||||||
|
mod test_utils;
|
||||||
|
use test_utils::to_vec2_round;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn quantized_matmul() -> Result<()> {
|
fn quantized_matmul() -> Result<()> {
|
||||||
@ -45,6 +47,54 @@ fn quantized_matmul() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn quantized_matmul_neg() -> Result<()> {
|
||||||
|
let cpu = &Device::Cpu;
|
||||||
|
let (m, k, n) = (3, 64, 4);
|
||||||
|
let lhs = (0..(m * k))
|
||||||
|
.map(|v| v as f32 - (m * k) as f32 / 2.0)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
|
||||||
|
let mut dst = vec![42.; 3 * 4];
|
||||||
|
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
|
||||||
|
let rhs = (0..k * n)
|
||||||
|
.map(|v| v as f32 - (k * n) as f32 / 3.0)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
|
||||||
|
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
|
||||||
|
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
|
||||||
|
assert_eq!(
|
||||||
|
dst,
|
||||||
|
&[
|
||||||
|
243524.14, -19596.34, -285051.3, -549814.94, 23776.629, 21650.926, 19397.924,
|
||||||
|
18366.586, -196472.1, 63011.6, 324584.56, 587901.9
|
||||||
|
]
|
||||||
|
);
|
||||||
|
let mm = tensor_lhs.matmul(&tensor_rhs)?;
|
||||||
|
assert_eq!(
|
||||||
|
to_vec2_round(&mm, 0)?,
|
||||||
|
&[
|
||||||
|
[244064.0, -20128.0, -284320.0, -548512.0],
|
||||||
|
[23563.0, 21515.0, 19467.0, 17419.0],
|
||||||
|
[-196939.0, 63157.0, 323253.0, 583349.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
let qtensor = quantized::QTensor::new(rhs_t, (4, 64));
|
||||||
|
let matmul = quantized::QMatMul::from_qtensor(qtensor);
|
||||||
|
let res = matmul.forward(&tensor_lhs)?;
|
||||||
|
assert_eq!(
|
||||||
|
to_vec2_round(&res, 0)?,
|
||||||
|
&[
|
||||||
|
[243524.0, -19596.0, -285051.0, -549815.0],
|
||||||
|
[23777.0, 21651.0, 19398.0, 18367.0],
|
||||||
|
[-196472.0, 63012.0, 324585.0, 587902.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn quantize_q4_0() -> Result<()> {
|
fn quantize_q4_0() -> Result<()> {
|
||||||
use k_quants::BlockQ4_0;
|
use k_quants::BlockQ4_0;
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
use candle::quantized::ggml_file::Content;
|
use candle::quantized::ggml_file::Content;
|
||||||
use candle::quantized::{QMatMul, QTensor};
|
use candle::quantized::{QMatMul, QTensor};
|
||||||
@ -259,7 +260,7 @@ struct Args {
|
|||||||
|
|
||||||
/// The tokenizer config in json format.
|
/// The tokenizer config in json format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer: String,
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -270,11 +271,24 @@ struct Args {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Args {
|
||||||
|
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
|
||||||
|
let tokenizer_path = match &self.tokenizer {
|
||||||
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
|
||||||
|
api.get("tokenizer.json")?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let mut file = std::fs::File::open(args.model)?;
|
let mut file = std::fs::File::open(&args.model)?;
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let model = Content::read(&mut file)?;
|
let model = Content::read(&mut file)?;
|
||||||
|
|
||||||
@ -303,7 +317,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let mut model = ModelWeights::new(model)?;
|
let mut model = ModelWeights::new(model)?;
|
||||||
println!("model built");
|
println!("model built");
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(args.tokenizer).map_err(anyhow::Error::msg)?;
|
let tokenizer = args.tokenizer()?;
|
||||||
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||||
let mut tokens = tokenizer
|
let mut tokens = tokenizer
|
||||||
.encode(prompt, true)
|
.encode(prompt, true)
|
||||||
@ -312,6 +326,8 @@ fn main() -> anyhow::Result<()> {
|
|||||||
.to_vec();
|
.to_vec();
|
||||||
let mut index_pos = 0;
|
let mut index_pos = 0;
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
let mut token_generated = 0;
|
||||||
print!("{prompt}");
|
print!("{prompt}");
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let context_size = if index == 0 { tokens.len() } else { 1 };
|
let context_size = if index == 0 { tokens.len() } else { 1 };
|
||||||
@ -322,6 +338,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
index_pos += ctxt.len();
|
index_pos += ctxt.len();
|
||||||
|
|
||||||
let next_token = logits_processor.sample(&logits)?;
|
let next_token = logits_processor.sample(&logits)?;
|
||||||
|
token_generated += 1;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
|
|
||||||
// Extracting the last token as a string is complicated, here we just apply some simple
|
// Extracting the last token as a string is complicated, here we just apply some simple
|
||||||
@ -334,5 +351,11 @@ fn main() -> anyhow::Result<()> {
|
|||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
println!(
|
||||||
|
"\n\n{} tokens generated ({} token/s)\n",
|
||||||
|
token_generated,
|
||||||
|
token_generated as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user