mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
fix index_pos bug when kv cache is disabled. (#1517)
* fix index_pos bug when kv cache is disabled * Tweak the fix. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -165,14 +165,14 @@ fn main() -> Result<()> {
|
|||||||
let mut index_pos = 0;
|
let mut index_pos = 0;
|
||||||
let mut token_generated = 0;
|
let mut token_generated = 0;
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let context_size = if cache.use_kv_cache && index > 0 {
|
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
|
||||||
1
|
(1, index_pos)
|
||||||
} else {
|
} else {
|
||||||
tokens.len()
|
(tokens.len(), 0)
|
||||||
};
|
};
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||||
let logits = llama.forward(&input, index_pos)?;
|
let logits = llama.forward(&input, context_index)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
let logits = if args.repeat_penalty == 1. {
|
let logits = if args.repeat_penalty == 1. {
|
||||||
logits
|
logits
|
||||||
|
Reference in New Issue
Block a user