Streaming mode for reporting the generated tokens (#1007)

* Token streaming.

* Use the token output stream.

* Flush the output.

* Ensure that the last characters get reported.
commit 06207332bc
parent 4021272875
Author: Laurent Mazare
Date: 2023-09-30 16:04:11 +02:00
Committed by: GitHub
4 changed files with 96 additions and 11 deletions
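The removed TODO in the diff below explains the motivation: calling `self.tokenizer.decode` on each token in isolation swallows the whitespace that the tokenizer only produces in context, so the example switches to `TokenOutputStream`, which decodes incrementally. As a rough sketch of that idea (illustrative only, not the actual `candle_examples::token_output_stream` implementation; the `StreamDecoder` name and fields are made up here), one can re-decode the whole running sequence after each token and emit just the new suffix, holding output back while the text still ends in an incomplete character:

use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

// Illustrative sketch of incremental detokenization; not the actual
// candle `TokenOutputStream`. Re-decoding the full running sequence
// and emitting only the new suffix preserves the whitespace that
// per-token `decode` calls swallow.
struct StreamDecoder {
    tokenizer: Tokenizer,
    tokens: Vec<u32>,
    printed: usize, // bytes of decoded text already emitted
}

impl StreamDecoder {
    fn new(tokenizer: Tokenizer) -> Self {
        Self { tokenizer, tokens: Vec::new(), printed: 0 }
    }

    // Feed one sampled token id; return any newly completed text.
    fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        self.tokens.push(token);
        let text = self.tokenizer.decode(&self.tokens, true).map_err(E::msg)?;
        // Hold output back unless the text ends in an alphanumeric char:
        // a token covering only part of a multi-byte character decodes to
        // a replacement char that should not be printed yet. (The real
        // helper also re-windows the token range; omitted here.)
        let complete = text.chars().last().map_or(false, |c| c.is_alphanumeric());
        if complete && text.len() > self.printed {
            let new = text[self.printed..].to_string();
            self.printed = text.len();
            Ok(Some(new))
        } else {
            Ok(None)
        }
    }

    // Counterpart of `decode_rest`: flush whatever is still pending so
    // the last characters get reported.
    fn decode_rest(&mut self) -> Result<String> {
        let text = self.tokenizer.decode(&self.tokens, true).map_err(E::msg)?;
        let rest = text[self.printed..].to_string();
        self.printed = text.len();
        Ok(rest)
    }
}

Usage mirrors the diff: call `next_token` for every sampled id, `print!` and flush the returned fragment, and call `decode_rest` once generation stops so trailing bytes are not lost.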


@@ -10,6 +10,7 @@ use clap::Parser;
 use candle_transformers::models::mistral::{Config, Model};
 use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -18,7 +19,7 @@ use tokenizers::Tokenizer;
 struct TextGeneration {
     model: Model,
     device: Device,
-    tokenizer: Tokenizer,
+    tokenizer: TokenOutputStream,
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
@@ -39,7 +40,7 @@ impl TextGeneration {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
-            tokenizer,
+            tokenizer: TokenOutputStream::new(tokenizer),
             logits_processor,
             repeat_penalty,
             repeat_last_n,
@@ -49,18 +50,24 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         println!("starting the inference loop");
         std::io::stdout().flush()?;
+        self.tokenizer.clear();
         let mut tokens = self
             .tokenizer
+            .tokenizer()
             .encode(prompt, true)
             .map_err(E::msg)?
             .get_ids()
             .to_vec();
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                print!("{t}")
+            }
+        }
+        std::io::stdout().flush()?;
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
-            Some(token) => *token,
+        let eos_token = match self.tokenizer.get_token("</s>") {
+            Some(token) => token,
             None => anyhow::bail!("cannot find the </s> token"),
         };
         let start_gen = std::time::Instant::now();
@@ -88,12 +95,15 @@ impl TextGeneration {
             if next_token == eos_token {
                 break;
             }
-            // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
-            // on each token seems to swallow the whitespaces.
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
         }
         let dt = start_gen.elapsed();
-        let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
-        println!("Generated text:\n{generated_text}");
+        let rest = self.tokenizer.decode_rest().map_err(E::msg)?;
+        print!("{rest}");
+        std::io::stdout().flush()?;
         println!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),