Fix `index_pos` bug when the KV cache is disabled (#1517)

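* Fix `index_pos` bug when the KV cache is disabled: whenever the full token context is passed to the model, the forward pass now receives a position offset of 0 instead of the accumulated `index_pos`.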

* Tweak the fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
optman
2024-01-06 18:43:01 +08:00
committed by GitHub
parent 8d1a57c9a0
commit 84250bf52f

View File

@@ -165,14 +165,14 @@ fn main() -> Result<()> {
let mut index_pos = 0; let mut index_pos = 0;
let mut token_generated = 0; let mut token_generated = 0;
for index in 0..args.sample_len { for index in 0..args.sample_len {
let context_size = if cache.use_kv_cache && index > 0 { let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
1 (1, index_pos)
} else { } else {
tokens.len() (tokens.len(), 0)
}; };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?; let logits = llama.forward(&input, context_index)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. { let logits = if args.repeat_penalty == 1. {
logits logits