Mistral: exit on eos token. (#1001)

* Mistral: exit on eos token.

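* Print the proper stats: report the number of tokens actually generated (`generated_tokens`) instead of the requested `sample_len`.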

* Also add a short flag (`-n`) for the `sample_len` argument.
This commit is contained in:
Laurent Mazare
2023-09-30 08:07:06 +02:00
committed by GitHub
parent 6203ced495
commit 87e3a4e175
2 changed files with 17 additions and 10 deletions

View File

@ -58,7 +58,11 @@ impl TextGeneration {
.get_ids() .get_ids()
.to_vec(); .to_vec();
let mut new_tokens = vec![]; let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0..sample_len { for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() }; let context_size = if index > 0 { 1 } else { tokens.len() };
@ -80,7 +84,10 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?; let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
new_tokens.push(next_token); generated_tokens += 1;
if next_token == eos_token {
break;
}
// TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode` // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
// on each token seems to swallow the whitespaces. // on each token seems to swallow the whitespaces.
} }
@ -88,8 +95,8 @@ impl TextGeneration {
let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?; let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
println!("Generated text:\n{generated_text}"); println!("Generated text:\n{generated_text}");
println!( println!(
"\n{sample_len} tokens generated ({:.2} token/s)", "\n{generated_tokens} tokens generated ({:.2} token/s)",
sample_len as f64 / dt.as_secs_f64(), generated_tokens as f64 / dt.as_secs_f64(),
); );
Ok(()) Ok(())
} }
@ -122,7 +129,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
#[arg(long, default_value = "lmz/candle-mistral")] #[arg(long, default_value = "lmz/candle-mistral")]

View File

@ -65,7 +65,7 @@ impl TextGeneration {
.get_ids() .get_ids()
.to_vec(); .to_vec();
let mut new_tokens = vec![]; let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token, Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"), None => anyhow::bail!("cannot find the endoftext token"),
@ -93,7 +93,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?; let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
new_tokens.push(next_token); generated_tokens += 1;
if next_token == eos_token { if next_token == eos_token {
break; break;
} }
@ -103,8 +103,8 @@ impl TextGeneration {
} }
let dt = start_gen.elapsed(); let dt = start_gen.elapsed();
println!( println!(
"\n{sample_len} tokens generated ({:.2} token/s)", "\n{generated_tokens} tokens generated ({:.2} token/s)",
sample_len as f64 / dt.as_secs_f64(), generated_tokens as f64 / dt.as_secs_f64(),
); );
Ok(()) Ok(())
} }
@ -137,7 +137,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
#[arg(long, default_value = "microsoft/phi-1_5")] #[arg(long, default_value = "microsoft/phi-1_5")]