Add a quantized test that uses negative values. (#470)

* Add a quantized test that uses negative values.
* Add a default tokenizer.

@@ -2,6 +2,7 @@
 use clap::Parser;
 use std::collections::HashMap;
 use std::io::Write;
+use tokenizers::Tokenizer;
 
 use candle::quantized::ggml_file::Content;
 use candle::quantized::{QMatMul, QTensor};
@@ -259,7 +260,7 @@ struct Args {
 
     /// The tokenizer config in json format.
     #[arg(long)]
-    tokenizer: String,
+    tokenizer: Option<String>,
 
     /// The temperature used to generate samples.
     #[arg(long)]
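
Switching the field from String to Option<String> is what makes the flag optional under clap's derive API: an omitted --tokenizer now parses as None instead of aborting with a missing-argument error. A minimal standalone sketch of the same pattern (the struct and messages here are illustrative, not from the commit):

use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// Optional path to a tokenizer config; None when the flag is omitted.
    #[arg(long)]
    tokenizer: Option<String>,
}

fn main() {
    let args = DemoArgs::parse();
    match args.tokenizer.as_deref() {
        Some(path) => println!("using tokenizer config at {path}"),
        None => println!("no --tokenizer flag given, falling back to a default"),
    }
}
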
@@ -270,11 +271,24 @@ struct Args {
     seed: u64,
 }
 
+impl Args {
+    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
+        let tokenizer_path = match &self.tokenizer {
+            Some(config) => std::path::PathBuf::from(config),
+            None => {
+                let api = hf_hub::api::sync::Api::new()?;
+                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+                api.get("tokenizer.json")?
+            }
+        };
+        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
+    }
+}
+
 fn main() -> anyhow::Result<()> {
-    use tokenizers::Tokenizer;
     let args = Args::parse();
 
-    let mut file = std::fs::File::open(args.model)?;
+    let mut file = std::fs::File::open(&args.model)?;
     let start = std::time::Instant::now();
     let model = Content::read(&mut file)?;
 
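
The new Args::tokenizer helper is where the default tokenizer comes from: when no --tokenizer is given, the hf-hub crate fetches tokenizer.json from the hf-internal-testing/llama-tokenizer repo (reusing the locally cached copy on later runs) and returns its local path. The same fetch pattern in isolation, built from the exact calls used above:

use anyhow::Result;

/// Resolve a local path to tokenizer.json, downloading it into the
/// Hugging Face cache on first use and reusing it afterwards.
fn default_tokenizer_path() -> Result<std::path::PathBuf> {
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("hf-internal-testing/llama-tokenizer".to_string());
    let path = repo.get("tokenizer.json")?;
    Ok(path)
}

Tokenizer::from_file returns a boxed error type, which is why the helper converts it through map_err(anyhow::Error::msg) before the ? can propagate it.
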
@@ -303,7 +317,7 @@ fn main() -> anyhow::Result<()> {
     let mut model = ModelWeights::new(model)?;
     println!("model built");
 
-    let tokenizer = Tokenizer::from_file(args.tokenizer).map_err(anyhow::Error::msg)?;
+    let tokenizer = args.tokenizer()?;
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -312,6 +326,8 @@ fn main() -> anyhow::Result<()> {
         .to_vec();
     let mut index_pos = 0;
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
     let start_gen = std::time::Instant::now();
+    let mut token_generated = 0;
+    print!("{prompt}");
     for index in 0..args.sample_len {
         let context_size = if index == 0 { tokens.len() } else { 1 };
@@ -322,6 +338,7 @@ fn main() -> anyhow::Result<()> {
         index_pos += ctxt.len();
 
         let next_token = logits_processor.sample(&logits)?;
+        token_generated += 1;
         tokens.push(next_token);
 
         // Extracting the last token as a string is complicated, here we just apply some simple
@@ -334,5 +351,11 @@ fn main() -> anyhow::Result<()> {
             std::io::stdout().flush()?;
         }
     }
+    let dt = start_gen.elapsed();
+    println!(
+        "\n\n{} tokens generated ({} token/s)\n",
+        token_generated,
+        token_generated as f64 / dt.as_secs_f64(),
+    );
     Ok(())
 }
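
The second half of the change adds simple throughput reporting: the prompt is echoed before generation, each sampled token bumps token_generated, and after the loop the count is divided by the wall-clock time of the generation loop. The same computation in isolation (the loop body is elided here, so the numbers are only illustrative):

use std::time::Instant;

fn main() {
    let start_gen = Instant::now();
    let mut token_generated = 0;
    for _ in 0..100 {
        // ... sampling one token per iteration would happen here ...
        token_generated += 1;
    }
    let dt = start_gen.elapsed();
    // token/s = tokens divided by elapsed wall-clock seconds
    println!(
        "{} tokens generated ({:.2} token/s)",
        token_generated,
        token_generated as f64 / dt.as_secs_f64(),
    );
}
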