Mistral: exit on eos token. (#1001)

* Mistral: exit on eos token.

* Print the proper stats.

* Also add a short flag.
This commit is contained in:
Laurent Mazare
2023-09-30 08:07:06 +02:00
committed by GitHub
parent 6203ced495
commit 87e3a4e175
2 changed files with 17 additions and 10 deletions

View File

@ -58,7 +58,11 @@ impl TextGeneration {
.get_ids()
.to_vec();
let mut new_tokens = vec![];
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@ -80,7 +84,10 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
// TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
// on each token seems to swallow the whitespaces.
}
@ -88,8 +95,8 @@ impl TextGeneration {
let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
println!("Generated text:\n{generated_text}");
println!(
"\n{sample_len} tokens generated ({:.2} token/s)",
sample_len as f64 / dt.as_secs_f64(),
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
@ -122,7 +129,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "lmz/candle-mistral")]

View File

@ -65,7 +65,7 @@ impl TextGeneration {
.get_ids()
.to_vec();
let mut new_tokens = vec![];
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
@ -93,7 +93,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
@ -103,8 +103,8 @@ impl TextGeneration {
}
let dt = start_gen.elapsed();
println!(
"\n{sample_len} tokens generated ({:.2} token/s)",
sample_len as f64 / dt.as_secs_f64(),
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
@ -137,7 +137,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "microsoft/phi-1_5")]