mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Mistral: print the generated text. (#992)
@@ -50,7 +50,6 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         println!("starting the inference loop");
-        print!("{prompt}");
         std::io::stdout().flush()?;
         let mut tokens = self
             .tokenizer
@@ -82,11 +81,12 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
-            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
-            print!("{token}");
-            std::io::stdout().flush()?;
+            // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
+            // on each token seems to swallow the whitespaces.
         }
         let dt = start_gen.elapsed();
+        let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
+        println!("Generated text:\n{generated_text}");
         println!(
             "\n{sample_len} tokens generated ({:.2} token/s)",
             sample_len as f64 / dt.as_secs_f64(),
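The TODO in the second hunk points at a real limitation: decoding a single token with `self.tokenizer.decode` drops the leading-whitespace markers that sentencepiece-style tokenizers attach, so naive per-token printing glues words together. A common workaround, sketched below under the assumption that the same `tokenizers::Tokenizer` and `anyhow` error handling as in this example are available, is to re-decode the whole sequence on every step and print only the newly produced suffix. The helper name `print_new_text` and the `printed_len` state are hypothetical, introduced for illustration; they are not part of this commit.

// A minimal sketch of streaming output that keeps whitespace intact, assuming
// `tokenizer` is the `tokenizers::Tokenizer` used elsewhere in this example.
use std::io::Write;

use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

fn print_new_text(tokenizer: &Tokenizer, tokens: &[u32], printed_len: &mut usize) -> Result<()> {
    // Decode the whole sequence generated so far: this preserves the
    // inter-token whitespace that a single-token decode would swallow.
    let text = tokenizer.decode(tokens, true).map_err(E::msg)?;
    // Print only the suffix that appeared since the last call. If the split
    // point falls inside a multi-byte UTF-8 character, skip this round and
    // flush the pending text once a later token completes the character.
    if text.len() > *printed_len && text.is_char_boundary(*printed_len) {
        print!("{}", &text[*printed_len..]);
        std::io::stdout().flush()?;
        *printed_len = text.len();
    }
    Ok(())
}

Called once per generated token with the growing `tokens` vector, this prints incrementally without the glued-together words of the removed per-token decode; candle's examples later introduced a `TokenOutputStream` helper built around the same full-sequence re-decode idea.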