Mistral: print the generated text. (#992)

This commit is contained in:
Laurent Mazare
2023-09-29 11:56:11 +02:00
committed by GitHub
parent 01b92cd959
commit 6f17ef82be

View File

@ -50,7 +50,6 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
@ -82,11 +81,12 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
// TODO: print the generated tokens in a streaming way. Calling `self.tokenizer.decode`
// on each token individually seems to swallow the whitespace between tokens.
}
let dt = start_gen.elapsed();
let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
println!("Generated text:\n{generated_text}");
println!(
"\n{sample_len} tokens generated ({:.2} token/s)",
sample_len as f64 / dt.as_secs_f64(),