From 32544a2ad691fa0fe6e77ade4f5b3232e8d311c1 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 27 Feb 2024 11:24:11 +0100
Subject: [PATCH] Add an option to split the prompt. (#1766)

---
 candle-examples/examples/quantized/main.rs | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 34c44233..a497e944 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -212,6 +212,10 @@ struct Args {
     #[arg(long)]
     verbose_prompt: bool,
 
+    /// Process prompt elements separately.
+    #[arg(long)]
+    split_prompt: bool,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@@ -487,11 +491,20 @@ fn main() -> anyhow::Result<()> {
     let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
 
     let start_prompt_processing = std::time::Instant::now();
-    let mut next_token = {
+    let mut next_token = if !args.split_prompt {
         let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, 0)?;
         let logits = logits.squeeze(0)?;
         logits_processor.sample(&logits)?
+    } else {
+        let mut next_token = 0;
+        for (pos, token) in prompt_tokens.iter().enumerate() {
+            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
+            let logits = model.forward(&input, pos)?;
+            let logits = logits.squeeze(0)?;
+            next_token = logits_processor.sample(&logits)?
+        }
+        next_token
+    };
     let prompt_dt = start_prompt_processing.elapsed();
     all_tokens.push(next_token);
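
Note: below is a minimal standalone sketch of the two prompt-processing paths the patch introduces. `Model`, `Logits`, `Echo`, and the greedy `sample` helper are illustrative stand-ins for candle's ModelWeights and LogitsProcessor, not real candle types; only the control flow mirrors the hunk above. Without --split-prompt the whole prompt goes through a single forward pass at position 0; with the flag, each token is fed individually at its own position so the model's cache advances one step at a time.

// Standalone sketch; stand-in types, not candle's API.
struct Logits(Vec<f32>);

trait Model {
    // `pos` is the position of the first token of `input` in the sequence,
    // so the model's internal cache keeps growing across successive calls.
    fn forward(&mut self, input: &[u32], pos: usize) -> Logits;
}

// Greedy stand-in for LogitsProcessor::sample.
fn sample(logits: &Logits) -> u32 {
    logits
        .0
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}

// Default path: the whole prompt in one forward pass at position 0.
fn process_whole(model: &mut dyn Model, prompt: &[u32]) -> u32 {
    sample(&model.forward(prompt, 0))
}

// --split-prompt path: one forward pass per token, each at its own position.
// Slower, but each step only processes a single-token batch.
fn process_split(model: &mut dyn Model, prompt: &[u32]) -> u32 {
    let mut next_token = 0;
    for (pos, &token) in prompt.iter().enumerate() {
        next_token = sample(&model.forward(&[token], pos));
    }
    next_token
}

// Toy model whose logits simply favour the last input token, for the demo.
struct Echo;

impl Model for Echo {
    fn forward(&mut self, input: &[u32], _pos: usize) -> Logits {
        let mut l = vec![0.0f32; 16];
        l[*input.last().unwrap() as usize % 16] = 1.0;
        Logits(l)
    }
}

fn main() {
    let prompt = [3u32, 7, 11];
    let whole = process_whole(&mut Echo, &prompt);
    let split = process_split(&mut Echo, &prompt);
    println!("whole-prompt sample: {whole}, split-prompt sample: {split}");
}

With the patch applied, the new clap flag would presumably be enabled like any other, e.g. something along the lines of `cargo run --example quantized --release -- --prompt "..." --split-prompt`.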