Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00.
Mistral: print the generated text. (#992)
This commit is contained in:
@@ -50,7 +50,6 @@ impl TextGeneration {
    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        println!("starting the inference loop");
        print!("{prompt}");
        std::io::stdout().flush()?;
        let mut tokens = self
            .tokenizer
@@ -82,11 +81,12 @@ impl TextGeneration {
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            new_tokens.push(next_token);
            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
            print!("{token}");
            std::io::stdout().flush()?;
            // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
            // on each token seems to swallow the whitespaces.
        }
        let dt = start_gen.elapsed();
        let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
        println!("Generated text:\n{generated_text}");
        println!(
            "\n{sample_len} tokens generated ({:.2} token/s)",
            sample_len as f64 / dt.as_secs_f64(),
Reference in New Issue
Block a user