diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index babd71a8..9c5168bf 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,5 +1,7 @@ use candle_core::{quantized, Device, Result, Tensor}; use quantized::{k_quants, GgmlType}; +mod test_utils; +use test_utils::to_vec2_round; #[test] fn quantized_matmul() -> Result<()> { @@ -45,6 +47,54 @@ fn quantized_matmul() -> Result<()> { Ok(()) } +#[test] +fn quantized_matmul_neg() -> Result<()> { + let cpu = &Device::Cpu; + let (m, k, n) = (3, 64, 4); + let lhs = (0..(m * k)) + .map(|v| v as f32 - (m * k) as f32 / 2.0) + .collect::<Vec<_>>(); + let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?; + let mut dst = vec![42.; 3 * 4]; + let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; + let rhs = (0..k * n) + .map(|v| v as f32 - (k * n) as f32 / 3.0) + .collect::<Vec<_>>(); + let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?; + k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; + k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; + assert_eq!( + dst, + &[ + 243524.14, -19596.34, -285051.3, -549814.94, 23776.629, 21650.926, 19397.924, + 18366.586, -196472.1, 63011.6, 324584.56, 587901.9 + ] + ); + let mm = tensor_lhs.matmul(&tensor_rhs)?; + assert_eq!( + to_vec2_round(&mm, 0)?, + &[ + [244064.0, -20128.0, -284320.0, -548512.0], + [23563.0, 21515.0, 19467.0, 17419.0], + [-196939.0, 63157.0, 323253.0, 583349.0] + ] + ); + + let qtensor = quantized::QTensor::new(rhs_t, (4, 64)); + let matmul = quantized::QMatMul::from_qtensor(qtensor); + let res = matmul.forward(&tensor_lhs)?; + assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243524.0, -19596.0, -285051.0, -549815.0], + [23777.0, 21651.0, 19398.0, 18367.0], + [-196472.0, 63012.0, 324585.0, 587902.0] + ] + ); + + Ok(()) +} + #[test] fn quantize_q4_0() -> Result<()> { use k_quants::BlockQ4_0; diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs index
20fa94cc..68e2267c 100644 --- a/candle-examples/examples/ggml/main.rs +++ b/candle-examples/examples/ggml/main.rs @@ -2,6 +2,7 @@ use clap::Parser; use std::collections::HashMap; use std::io::Write; +use tokenizers::Tokenizer; use candle::quantized::ggml_file::Content; use candle::quantized::{QMatMul, QTensor}; @@ -259,7 +260,7 @@ struct Args { /// The tokenizer config in json format. #[arg(long)] - tokenizer: String, + tokenizer: Option<String>, /// The temperature used to generate samples. #[arg(long)] @@ -270,11 +271,24 @@ struct Args { seed: u64, } +impl Args { + fn tokenizer(&self) -> anyhow::Result<Tokenizer> { + let tokenizer_path = match &self.tokenizer { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("hf-internal-testing/llama-tokenizer".to_string()); + api.get("tokenizer.json")? + } + }; + Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg) + } +} + fn main() -> anyhow::Result<()> { - use tokenizers::Tokenizer; let args = Args::parse(); - let mut file = std::fs::File::open(args.model)?; + let mut file = std::fs::File::open(&args.model)?; let start = std::time::Instant::now(); let model = Content::read(&mut file)?; @@ -303,7 +317,7 @@ fn main() -> anyhow::Result<()> { let mut model = ModelWeights::new(model)?; println!("model built"); - let tokenizer = Tokenizer::from_file(args.tokenizer).map_err(anyhow::Error::msg)?; + let tokenizer = args.tokenizer()?; let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); let mut tokens = tokenizer .encode(prompt, true) @@ -312,6 +326,8 @@ .map_err(anyhow::Error::msg)? .get_ids() .to_vec(); let mut index_pos = 0; let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); + let start_gen = std::time::Instant::now(); + let mut token_generated = 0; print!("{prompt}"); for index in 0..args.sample_len { let context_size = if index == 0 { tokens.len() } else { 1 }; @@ -322,6 +338,7 @@
index_pos += ctxt.len(); let next_token = logits_processor.sample(&logits)?; + token_generated += 1; tokens.push(next_token); // Extracting the last token as a string is complicated, here we just apply some simple @@ -334,5 +351,11 @@ fn main() -> anyhow::Result<()> { std::io::stdout().flush()?; } } + let dt = start_gen.elapsed(); + println!( + "\n\n{} tokens generated ({} token/s)\n", + token_generated, + token_generated as f64 / dt.as_secs_f64(), + ); Ok(()) }