Add a KV cache to T5. (#873)

* Add a KV cache to T5.

* Suggest using release mode.

* Use the kv cache in decoding.

* Add a comment.
Laurent Mazare
2023-09-17 09:00:45 +02:00
committed by GitHub
parent 8658df3485
commit 1a276b5da7
5 changed files with 577 additions and 50 deletions
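For readers unfamiliar with the technique: a KV cache stores the attention key/value projections of already-decoded tokens, so each decoding step only computes projections for the newest token instead of re-running attention over the whole prefix. The sketch below illustrates that pattern in plain Rust; `KvCache` and `append` are illustrative names chosen here, not the candle API introduced by this commit.

// A minimal sketch of the key/value caching pattern used in autoregressive
// decoding. `KvCache` and `append` are illustrative, not the candle API.
struct KvCache {
    // Cached keys and values, one inner Vec per already-decoded position.
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}

impl KvCache {
    fn new() -> Self {
        Self { keys: Vec::new(), values: Vec::new() }
    }

    // On each decoding step, only the new token's key/value are computed;
    // earlier entries are reused instead of being recomputed from scratch.
    fn append(&mut self, key: Vec<f32>, value: Vec<f32>) -> (&[Vec<f32>], &[Vec<f32>]) {
        self.keys.push(key);
        self.values.push(value);
        (&self.keys, &self.values)
    }
}

fn main() {
    let mut cache = KvCache::new();
    // Step 0: decode the first token; the cache holds one entry afterwards.
    let (k, _v) = cache.append(vec![0.1, 0.2], vec![0.3, 0.4]);
    assert_eq!(k.len(), 1);
    // Step 1: only the new token's projections are added.
    let (k, _v) = cache.append(vec![0.5, 0.6], vec![0.7, 0.8]);
    assert_eq!(k.len(), 2);
}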

@@ -48,10 +48,6 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,
 
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "1")]
-    n: usize,
-
     /// L2 normalization for embeddings.
     #[arg(long, default_value = "true")]
     normalize_embeddings: bool,
@@ -131,6 +127,7 @@ impl T5ModelBuilder {
 fn main() -> Result<()> {
     let args = Args::parse();
     let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
+    let device = &builder.device;
     let tokenizer = tokenizer
         .with_padding(None)
         .with_truncation(None)
@@ -142,32 +139,32 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let input_token_ids = Tensor::new(&tokens[..], &builder.device)?.unsqueeze(0)?;
+    let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
     if !args.decode {
-        let model = builder.build_encoder()?;
-        for idx in 0..args.n {
-            let start = std::time::Instant::now();
-            let ys = model.forward(&input_token_ids)?;
-            if idx == 0 {
-                println!("{ys}");
-            }
-            println!("Took {:?}", start.elapsed());
-        }
+        let mut model = builder.build_encoder()?;
+        let start = std::time::Instant::now();
+        let ys = model.forward(&input_token_ids)?;
+        println!("{ys}");
+        println!("Took {:?}", start.elapsed());
     } else {
-        let model = builder.build_conditional_generation()?;
+        let mut model = builder.build_conditional_generation()?;
         let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
         let mut logits_processor = LogitsProcessor::new(299792458, None, None);
         let start = std::time::Instant::now();
-        for _index in 0.. {
+        for index in 0.. {
             if output_token_ids.len() > 512 {
                 break;
             }
-            let decoder_token_ids =
-                Tensor::new(&output_token_ids[..], &builder.device)?.unsqueeze(0)?;
+            let decoder_token_ids = if index == 0 || !builder.config.use_cache {
+                Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+            } else {
+                let last_token = *output_token_ids.last().unwrap();
+                Tensor::new(&[last_token], device)?.unsqueeze(0)?
+            };
             let logits = model.forward(&input_token_ids, &decoder_token_ids)?;
             let next_token_id = logits_processor.sample(&logits.flatten_to(1)?)?;
-            if (next_token_id as usize) == builder.config.eos_token_id {
+            if next_token_id as usize == builder.config.eos_token_id {
                 break;
             }
             output_token_ids.push(next_token_id);
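The conditional on `index == 0 || !builder.config.use_cache` is where the cache pays off: the full prefix is fed only on the first step (or when caching is disabled); afterwards only the newest token crosses into `forward`, since the keys/values of earlier positions already live inside the model. A stubbed-out sketch of that control flow, with `StubModel` as a hypothetical stand-in for the conditional generation model, not candle code:

// Sketch of the incremental decoding pattern above, with a stubbed model.
struct StubModel {
    cached_len: usize, // number of positions already in the KV cache
}

impl StubModel {
    // With a KV cache, `forward` only needs the tokens that are not cached yet.
    fn forward(&mut self, new_tokens: &[u32]) -> u32 {
        self.cached_len += new_tokens.len();
        // Pretend the next token is derived from the total decoded length.
        self.cached_len as u32
    }
}

fn main() {
    let use_cache = true;
    let mut model = StubModel { cached_len: 0 };
    let mut output_token_ids: Vec<u32> = vec![0]; // start from the pad token
    for index in 0..4 {
        // First step (or cache disabled): feed the whole prefix.
        // Later steps: feed only the most recent token, as in the diff above.
        let step_input: Vec<u32> = if index == 0 || !use_cache {
            output_token_ids.clone()
        } else {
            vec![*output_token_ids.last().unwrap()]
        };
        let next = model.forward(&step_input);
        output_token_ids.push(next);
    }
    assert_eq!(model.cached_len, 4); // each token's K/V was computed once
}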
@@ -186,7 +183,7 @@ fn main() -> Result<()> {
             }
         }
         None => {
-            let model = builder.build_encoder()?;
+            let mut model = builder.build_encoder()?;
             let sentences = [
                 "The cat sits outside",
                 "A man is playing guitar",