mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a KV cache to marian decoding. (#1226)
This commit is contained in:
@ -8,6 +8,7 @@ use anyhow::Error as E;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::marian;
|
||||
|
||||
@ -87,6 +88,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
};
|
||||
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = {
|
||||
@ -107,7 +109,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
};
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
||||
};
|
||||
let model = marian::MTModel::new(&config, vb)?;
|
||||
let mut model = marian::MTModel::new(&config, vb)?;
|
||||
|
||||
let mut logits_processor =
|
||||
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
|
||||
@ -125,23 +127,26 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let mut token_ids = vec![config.decoder_start_token_id];
|
||||
for index in 0..1000 {
|
||||
// TODO: Add a kv cache.
|
||||
let context_size = if index >= 1000 { 1 } else { token_ids.len() };
|
||||
let context_size = if index >= 1 { 1 } else { token_ids.len() };
|
||||
let start_pos = token_ids.len().saturating_sub(context_size);
|
||||
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
|
||||
let logits = model.decode(&input_ids, &encoder_xs)?;
|
||||
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
token_ids.push(token);
|
||||
println!("{token}");
|
||||
if let Some(t) = tokenizer_dec.next_token(token)? {
|
||||
use std::io::Write;
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
if token == config.eos_token_id || token == config.forced_eos_token_id {
|
||||
break;
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"{}",
|
||||
tokenizer_dec.decode(&token_ids, true).map_err(E::msg)?
|
||||
);
|
||||
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
println!();
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user