mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Streaming mode for reporting the generated tokens (#1007)
* Token streaming. * Use the token output stream. * Flush the output. * Ensure that the last characters get reported.
This commit is contained in:
74
candle-examples/src/token_output_stream.rs
Normal file
74
candle-examples/src/token_output_stream.rs
Normal file
@ -0,0 +1,74 @@
|
||||
use candle::Result;
|
||||
|
||||
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
|
||||
/// streaming way rather than having to wait for the full decoding.
|
||||
pub struct TokenOutputStream {
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
tokens: Vec<u32>,
|
||||
prev_index: usize,
|
||||
current_index: usize,
|
||||
}
|
||||
|
||||
impl TokenOutputStream {
|
||||
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
tokens: Vec::new(),
|
||||
prev_index: 0,
|
||||
current_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> tokenizers::Tokenizer {
|
||||
self.tokenizer
|
||||
}
|
||||
|
||||
fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||
match self.tokenizer.decode(tokens, true) {
|
||||
Ok(str) => Ok(str),
|
||||
Err(err) => candle::bail!("cannot decode: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
|
||||
pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
|
||||
let prev_text = if self.tokens.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let tokens = &self.tokens[self.prev_index..self.current_index];
|
||||
self.decode(tokens)?
|
||||
};
|
||||
self.tokens.push(token);
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
self.prev_index = self.current_index;
|
||||
self.current_index = self.tokens.len();
|
||||
Ok(Some(text.1.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_rest(&self) -> Result<String> {
|
||||
self.decode(&self.tokens[self.prev_index..])
|
||||
}
|
||||
|
||||
pub fn decode_all(&self) -> Result<String> {
|
||||
self.decode(&self.tokens)
|
||||
}
|
||||
|
||||
pub fn get_token(&self, token_s: &str) -> Option<u32> {
|
||||
self.tokenizer.get_vocab(true).get(token_s).copied()
|
||||
}
|
||||
|
||||
pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
|
||||
&self.tokenizer
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.tokens.clear();
|
||||
self.prev_index = 0;
|
||||
self.current_index = 0;
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user