Explicit caching in llama2.c.
@@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
 use std::io::Write;
 use tokenizers::Tokenizer;
 
-use model::{Config, Llama};
+use model::{Cache, Config, Llama};
 use qmodel::QLlama;
 use weights::TransformerWeights;
 
@@ -160,10 +160,10 @@ enum Model {
 }
 
 impl Model {
-    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
+    fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
         match self {
-            Self::Llama(l) => Ok(l.forward(xs, pos)?),
-            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
+            Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
+            Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
         }
     }
 }
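This new `forward` signature, together with the `Cache` import above, is the heart of the commit: the cache no longer lives inside the model, so the model can be called through `&self` while all mutable decoding state is owned by the caller and passed down explicitly. Below is a minimal, dependency-free sketch of that ownership split; the `Cache` fields and the toy `Tensor` alias are illustrative assumptions, not candle's actual definitions.

// Sketch only: shows the explicit-cache pattern, not candle's real types.
type Tensor = Vec<f32>; // stand-in for candle::Tensor in this sketch

struct Cache {
    use_kv_cache: bool,
    // One (key, value) slot per transformer layer, filled while decoding.
    kvs: Vec<Option<(Tensor, Tensor)>>,
}

impl Cache {
    fn new(use_kv_cache: bool, n_layers: usize) -> Self {
        Self { use_kv_cache, kvs: vec![None; n_layers] }
    }
}

struct Llama {
    n_layers: usize,
}

impl Llama {
    // `&self`: the weights stay read-only; only `cache` is mutated.
    fn forward(&self, xs: &Tensor, _pos: usize, cache: &mut Cache) -> Tensor {
        for layer in 0..self.n_layers {
            if cache.use_kv_cache {
                // Real code would concatenate this step's keys/values onto
                // the cached ones along the sequence dimension.
                let (k, v) = cache.kvs[layer].take().unwrap_or_default();
                cache.kvs[layer] = Some((k, v));
            }
        }
        xs.clone() // placeholder for the actual logits
    }
}

fn main() {
    let model = Llama { n_layers: 2 };
    // One immutable model can now serve several independent caches.
    let mut cache = Cache::new(true, model.n_layers);
    let xs = vec![0.0f32; 4];
    let _logits = model.forward(&xs, 0, &mut cache);
}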
@@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     let config = Config::from_reader(&mut file)?;
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
     let vb = weights.var_builder(&config, &device)?;
-    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
+    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, config)?;
 
     let tokens = match &args.pretokenized_dir {
         None => {
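Worth noting: this evaluation path constructs the cache with `false` while every construction in `run_inference` below passes `true`. Judging by the two call sites, the boolean appears to toggle KV caching, which evaluation does not need since whole sequences are fed at position 0. Even so, the `forward` call in the next hunk must still hand over `&mut cache`, because the parameter is now part of the shared signature.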
@@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
-        let logits = model.forward(&inp, 0)?;
+        let logits = model.forward(&inp, 0, &mut cache)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
         println!("{}", loss.to_vec0::<f32>()?);
     }
@@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let is_safetensors = config_path
         .extension()
         .map_or(false, |v| v == "safetensors");
-    let (model, config) = if is_gguf {
+    let (model, config, mut cache) = if is_gguf {
         let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
@@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
             &device,
         );
         let cache = model::Cache::new(true, &config, fake_vb)?;
-        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::QLlama(QLlama::load(vb, config.clone())?);
+        (model, config, cache)
     } else if is_safetensors {
         let config = Config::tiny_15m();
         let tensors = candle::safetensors::load(config_path, &device)?;
         let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
         let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::Llama(Llama::load(vb, config.clone())?);
+        (model, config, cache)
     } else {
         let mut file = std::fs::File::open(config_path)?;
         let config = Config::from_reader(&mut file)?;
@@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
         let vb = weights.var_builder(&config, &device)?;
         let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::Llama(Llama::load(vb, config.clone())?);
+        (model, config, cache)
     };
 
     println!("starting the inference loop");
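All three branches now evaluate to a `(model, config, cache)` tuple, and the `mut` in the outer `let (model, config, mut cache)` applies to the cache binding alone. A tiny self-contained illustration of that Rust pattern, with throwaway stand-in types:

// `mut` can mark a single element of a tuple pattern, so after this `let`
// only `cache` is mutable; `model` and `config` stay immutable bindings.
struct Cache(u32);

fn main() {
    let is_gguf = false;
    let (model, config, mut cache) = if is_gguf {
        ("qllama", 42u32, Cache(0))
    } else {
        ("llama", 7u32, Cache(1))
    };
    cache.0 += 1; // fine: `cache` was bound with `mut`
    println!("{model} {config} {}", cache.0);
}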
@@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = model.forward(&input, index_pos)?;
+        let logits = model.forward(&input, index_pos, &mut cache)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
             logits
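The surrounding context also shows why `context_size` drops to 1 after the first iteration: the mutable cache carries the keys and values of all earlier positions across calls, so each later `forward` only needs the newly sampled token, with `index_pos` supplying its absolute position in the sequence.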