From 87e3a4e175c3c78889fa7463a97256d0365b0327 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 30 Sep 2023 08:07:06 +0200
Subject: [PATCH] Mistral: exit on eos token. (#1001)

* Mistral: exit on eos token.

* Print the proper stats.

* Also add a short flag.
---
 candle-examples/examples/mistral/main.rs | 17 ++++++++++++-----
 candle-examples/examples/phi/main.rs     | 10 +++++-----
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 0c203593..b8c74a2f 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -58,7 +58,11 @@ impl TextGeneration {
             .get_ids()
             .to_vec();
 
-        let mut new_tokens = vec![];
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
+            Some(token) => *token,
+            None => anyhow::bail!("cannot find the </s> token"),
+        };
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -80,7 +84,10 @@ impl TextGeneration {
 
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            new_tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token {
+                break;
+            }
             // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
             // on each token seems to swallow the whitespaces.
         }
@@ -88,8 +95,8 @@ impl TextGeneration {
         let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
         println!("Generated text:\n{generated_text}");
         println!(
-            "\n{sample_len} tokens generated ({:.2} token/s)",
-            sample_len as f64 / dt.as_secs_f64(),
+            "\n{generated_tokens} tokens generated ({:.2} token/s)",
+            generated_tokens as f64 / dt.as_secs_f64(),
         );
         Ok(())
     }
@@ -122,7 +129,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
     #[arg(long, default_value = "lmz/candle-mistral")]
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index ab37ed5f..eff329ff 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -65,7 +65,7 @@ impl TextGeneration {
             .get_ids()
             .to_vec();
 
-        let mut new_tokens = vec![];
+        let mut generated_tokens = 0usize;
         let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
             Some(token) => *token,
             None => anyhow::bail!("cannot find the endoftext token"),
@@ -93,7 +93,7 @@ impl TextGeneration {
 
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            new_tokens.push(next_token);
+            generated_tokens += 1;
             if next_token == eos_token {
                 break;
             }
@@ -103,8 +103,8 @@ impl TextGeneration {
         }
         let dt = start_gen.elapsed();
         println!(
-            "\n{sample_len} tokens generated ({:.2} token/s)",
-            sample_len as f64 / dt.as_secs_f64(),
+            "\n{generated_tokens} tokens generated ({:.2} token/s)",
+            generated_tokens as f64 / dt.as_secs_f64(),
         );
         Ok(())
     }
@@ -137,7 +137,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
     #[arg(long, default_value = "microsoft/phi-1_5")]
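
Note: the pattern both hunks converge on is the same: resolve the eos token id from the tokenizer vocabulary before the loop, count only the tokens actually produced, and break as soon as eos is sampled so the token/s stat reflects real output. The following is a minimal self-contained sketch of that loop, assuming the `tokenizers` crate; `sample_next_token` is a hypothetical stand-in for the model forward pass plus logits sampling, which the real examples do via candle's `LogitsProcessor`.

use std::time::Instant;

use anyhow::Result;
use tokenizers::Tokenizer;

// Sketch of the eos-exit loop introduced by this patch. The
// `sample_next_token` closure is a hypothetical stand-in for the
// model forward pass and sampling step.
fn generate(
    tokenizer: &Tokenizer,
    mut tokens: Vec<u32>,
    sample_len: usize,
    mut sample_next_token: impl FnMut(&[u32]) -> Result<u32>,
) -> Result<String> {
    // Resolve the eos id up front so a missing special token fails fast.
    let eos_token = match tokenizer.get_vocab(true).get("</s>") {
        Some(token) => *token,
        None => anyhow::bail!("cannot find the </s> token"),
    };
    let mut generated_tokens = 0usize;
    let start_gen = Instant::now();
    for _ in 0..sample_len {
        let next_token = sample_next_token(&tokens)?;
        tokens.push(next_token);
        generated_tokens += 1;
        // Exit as soon as the model emits eos rather than padding out
        // the full sample_len.
        if next_token == eos_token {
            break;
        }
    }
    let dt = start_gen.elapsed();
    // Report the tokens actually produced, not the requested sample_len.
    println!(
        "\n{generated_tokens} tokens generated ({:.2} token/s)",
        generated_tokens as f64 / dt.as_secs_f64(),
    );
    tokenizer.decode(&tokens, true).map_err(anyhow::Error::msg)
}

With the new short flag, the sample length can be passed as e.g. `cargo run --example mistral --release -- -n 200` instead of spelling out `--sample-len`.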