From 618ecf5e231beb5bd0b1e59a171eb9cb0af95b01 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 22 Apr 2024 17:54:27 +0200
Subject: [PATCH] Better time measurement for the llama example. (#2106)

---
 candle-examples/examples/llama/main.rs | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index fa30686d..72656295 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -167,7 +167,7 @@ fn main() -> Result<()> {
     println!("starting the inference loop");
     print!("{prompt}");
     let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
-    let start_gen = std::time::Instant::now();
+    let mut start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
     for index in 0..args.sample_len {
@@ -176,6 +176,9 @@ fn main() -> Result<()> {
         } else {
             (tokens.len(), 0)
         };
+        if index == 1 {
+            start_gen = std::time::Instant::now()
+        }
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = llama.forward(&input, context_index, &mut cache)?;
@@ -211,7 +214,7 @@ fn main() -> Result<()> {
     println!(
         "\n\n{} tokens generated ({} token/s)\n",
         token_generated,
-        token_generated as f64 / dt.as_secs_f64(),
+        (token_generated - 1) as f64 / dt.as_secs_f64(),
    );
     Ok(())
 }
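
Note: the rationale behind this change is that the first loop iteration runs the whole prompt through the model, so including it in the timed window inflates the elapsed time and understates the steady-state token rate. Restarting the timer at `index == 1` and dividing by `token_generated - 1` measures only the incremental decoding steps. Below is a minimal standalone sketch of the same measurement pattern, not part of the patch; `generate_one_token` is a hypothetical stand-in for the model forward pass.

```rust
use std::time::{Duration, Instant};

// Hypothetical stand-in for one decoding step; the first step
// (prompt processing) is simulated as much slower than the rest.
fn generate_one_token(step: usize) -> u32 {
    std::thread::sleep(Duration::from_millis(if step == 0 { 200 } else { 20 }));
    step as u32
}

fn main() {
    let sample_len = 10;
    let mut start_gen = Instant::now();
    let mut token_generated = 0usize;
    for index in 0..sample_len {
        if index == 1 {
            // Restart the clock once the prompt has been processed,
            // so the window covers only per-token generation.
            start_gen = Instant::now();
        }
        let _token = generate_one_token(index);
        token_generated += 1;
    }
    let dt = start_gen.elapsed();
    // Exclude the first token from the rate, mirroring the patch.
    println!(
        "{} tokens generated ({:.2} token/s)",
        token_generated,
        (token_generated - 1) as f64 / dt.as_secs_f64(),
    );
}
```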