mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a quantized variant of llama2.c (#1197)
* Add a quantized variant of llama2.c * Clippy fixes.
This commit is contained in:
@ -7,6 +7,7 @@ extern crate accelerate_src;
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
mod model;
|
||||
mod qmodel;
|
||||
mod training;
|
||||
mod weights;
|
||||
use clap::{Parser, Subcommand};
|
||||
@ -19,6 +20,7 @@ use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use model::{Config, Llama};
|
||||
use qmodel::QLlama;
|
||||
use weights::TransformerWeights;
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
enum Model {
|
||||
Llama(Llama),
|
||||
QLlama(QLlama),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
|
||||
match self {
|
||||
Self::Llama(l) => Ok(l.forward(xs, pos)?),
|
||||
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
||||
use std::io::BufRead;
|
||||
|
||||
@ -241,24 +257,56 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
|
||||
let device = candle_examples::device(common_args.cpu)?;
|
||||
|
||||
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
|
||||
let is_safetensors = config_path
|
||||
.extension()
|
||||
.map_or(false, |v| v == "safetensors");
|
||||
let (vb, config) = if is_safetensors {
|
||||
let (model, config) = if is_gguf {
|
||||
let config = Config::tiny();
|
||||
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
|
||||
let freq_cis_real = vb
|
||||
.get(
|
||||
(config.seq_len, config.head_size() / 2),
|
||||
"rot.freq_cis_real",
|
||||
)?
|
||||
.dequantize(&candle::Device::Cpu)?;
|
||||
let freq_cis_imag = vb
|
||||
.get(
|
||||
(config.seq_len, config.head_size() / 2),
|
||||
"rot.freq_cis_imag",
|
||||
)?
|
||||
.dequantize(&candle::Device::Cpu)?;
|
||||
|
||||
let fake_vb = candle_nn::VarBuilder::from_tensors(
|
||||
[
|
||||
("freq_cis_real".to_string(), freq_cis_real),
|
||||
("freq_cis_imag".to_string(), freq_cis_imag),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
candle::DType::F32,
|
||||
&candle::Device::Cpu,
|
||||
);
|
||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
} else if is_safetensors {
|
||||
let config = Config::tiny();
|
||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||
(vb, config)
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
} else {
|
||||
let mut file = std::fs::File::open(config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
println!("{config:?}");
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
(vb, config)
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
};
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
|
||||
@ -273,7 +321,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0.. {
|
||||
if tokens.len() >= model.config.seq_len {
|
||||
if tokens.len() >= config.seq_len {
|
||||
break;
|
||||
}
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
|
Reference in New Issue
Block a user