Mistral: print the generated text. (#992)

Author: Laurent Mazare
Date:   2023-09-29 11:56:11 +02:00 (committed by GitHub)
Parent: 01b92cd959
Commit: 6f17ef82be

@@ -50,7 +50,6 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         println!("starting the inference loop");
-        print!("{prompt}");
         std::io::stdout().flush()?;
         let mut tokens = self
             .tokenizer
@@ -82,11 +81,12 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
-            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
-            print!("{token}");
-            std::io::stdout().flush()?;
+            // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
+            // on each token seems to swallow the whitespaces.
         }
         let dt = start_gen.elapsed();
+        let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
+        println!("Generated text:\n{generated_text}");
         println!(
             "\n{sample_len} tokens generated ({:.2} token/s)",
             sample_len as f64 / dt.as_secs_f64(),