Mistral: exit on eos token. (#1001)

* Mistral: exit on eos token.

* Print the proper stats.

* Also add a short flag.
This commit is contained in:
Laurent Mazare
2023-09-30 08:07:06 +02:00
committed by GitHub
parent 6203ced495
commit 87e3a4e175
2 changed files with 17 additions and 10 deletions

View File

@ -65,7 +65,7 @@ impl TextGeneration {
.get_ids()
.to_vec();
let mut new_tokens = vec![];
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
@ -93,7 +93,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
@ -103,8 +103,8 @@ impl TextGeneration {
}
let dt = start_gen.elapsed();
println!(
"\n{sample_len} tokens generated ({:.2} token/s)",
sample_len as f64 / dt.as_secs_f64(),
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
@ -137,7 +137,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "microsoft/phi-1_5")]