mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add a KV cache to falcon. (#104)
This commit is contained in:
@ -51,7 +51,13 @@ impl TextGeneration {
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||
let context_size = if self.model.config().use_cache && index > 0 {
|
||||
1
|
||||
} else {
|
||||
tokens.len()
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
|
Reference in New Issue
Block a user