Add a KV cache to falcon. (#104)

This commit is contained in:
Laurent Mazare
2023-07-07 17:24:38 +01:00
committed by GitHub
parent 05ff1cff66
commit e923b3adc2
3 changed files with 80 additions and 43 deletions

View File

@ -51,7 +51,13 @@ impl TextGeneration {
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let start_gen = std::time::Instant::now();
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let context_size = if self.model.config().use_cache && index > 0 {
1
} else {
tokens.len()
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;