Add a KV cache to marian decoding. (#1226)

Laurent Mazare
2023-10-31 09:47:44 +01:00
committed by GitHub
parent 7d0202710b
commit c12ad45562
3 changed files with 55 additions and 24 deletions

@@ -8,6 +8,7 @@ use anyhow::Error as E;
 use clap::{Parser, ValueEnum};
 use candle::{DType, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::models::marian;
@@ -87,6 +88,7 @@ pub fn main() -> anyhow::Result<()> {
         };
         Tokenizer::from_file(&tokenizer).map_err(E::msg)?
     };
+    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
     let device = candle_examples::device(args.cpu)?;
     let vb = {
@@ -107,7 +109,7 @@ pub fn main() -> anyhow::Result<()> {
         };
         unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
     };
-    let model = marian::MTModel::new(&config, vb)?;
+    let mut model = marian::MTModel::new(&config, vb)?;
 
     let mut logits_processor =
         candle_transformers::generation::LogitsProcessor::new(1337, None, None);
@@ -125,23 +127,26 @@ pub fn main() -> anyhow::Result<()> {
     let mut token_ids = vec![config.decoder_start_token_id];
     for index in 0..1000 {
-        // TODO: Add a kv cache.
-        let context_size = if index >= 1000 { 1 } else { token_ids.len() };
+        let context_size = if index >= 1 { 1 } else { token_ids.len() };
         let start_pos = token_ids.len().saturating_sub(context_size);
         let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
-        let logits = model.decode(&input_ids, &encoder_xs)?;
+        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
         token_ids.push(token);
-        println!("{token}");
+        if let Some(t) = tokenizer_dec.next_token(token)? {
+            use std::io::Write;
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
         if token == config.eos_token_id || token == config.forced_eos_token_id {
             break;
         }
     }
-    println!(
-        "{}",
-        tokenizer_dec.decode(&token_ids, true).map_err(E::msg)?
-    );
+    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
+    println!();
     Ok(())
 }
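
Only the example-side changes are shown above; the decoder-side edits in candle-transformers (where the cache itself lives, per the "3 changed files" summary) are not visible in this view. Note also that the old `index >= 1000` condition could never fire inside a 1000-step loop, so every step re-processed the whole prefix; with the cache in place the condition becomes `index >= 1` and only the latest token is fed, with `start_pos` keeping positions aligned with the cached entries, which is also why `model` now has to be `mut`. As a rough sketch of the pattern such a cache typically follows — illustrative names only, not the actual marian.rs code — each attention layer keeps the keys/values computed so far and concatenates the new step onto them along the sequence dimension:

use candle::{Result, Tensor};

// Hypothetical per-layer cache, illustrative only: the actual field and
// method names in candle-transformers may differ.
#[derive(Clone, Default)]
struct KvCacheSketch {
    kv: Option<(Tensor, Tensor)>,
}

impl KvCacheSketch {
    // Append the keys/values of the newly decoded position(s) and return the
    // full key/value tensors seen so far. With (batch, heads, seq, head_dim)
    // shaped tensors, concatenation happens along dim 2 (the sequence axis).
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        let kv = match self.kv.take() {
            None => (k.clone(), v.clone()),
            Some((prev_k, prev_v)) => (
                Tensor::cat(&[&prev_k, k], 2)?,
                Tensor::cat(&[&prev_v, v], 2)?,
            ),
        };
        self.kv = Some(kv.clone());
        Ok(kv)
    }

    // Clearing the cache between unrelated inputs is essential; otherwise
    // keys/values from the previous translation leak into the next one.
    fn reset(&mut self) {
        self.kv = None;
    }
}

With a cache of this shape, per-step decoding cost stays roughly constant instead of growing with the length of the generated prefix.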