mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Streaming mode for reporting the generated tokens (#1007)
* Token streaming. * Use the token output stream. * Flush the output. * Ensure that the last characters get reported.
This commit is contained in:
@ -10,6 +10,7 @@ use clap::Parser;
|
||||
use candle_transformers::models::mistral::{Config, Model};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
@ -18,7 +19,7 @@ use tokenizers::Tokenizer;
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
@ -39,7 +40,7 @@ impl TextGeneration {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
@ -49,18 +50,24 @@ impl TextGeneration {
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
std::io::stdout().flush()?;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
|
||||
Some(token) => *token,
|
||||
let eos_token = match self.tokenizer.get_token("</s>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the </s> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
@ -88,12 +95,15 @@ impl TextGeneration {
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
// TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
|
||||
// on each token seems to swallow the whitespaces.
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
|
||||
println!("Generated text:\n{generated_text}");
|
||||
let rest = self.tokenizer.decode_rest().map_err(E::msg)?;
|
||||
print!("{rest}");
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
|
Reference in New Issue
Block a user