Fix token generation in bilingual models (non-English outputs) (#1668)

Co-authored-by: Guoqing Bao <guoqing.bao@enflame-tech.com>
Guoqing Bao
2024-02-06 19:03:53 +08:00
committed by GitHub
parent b545f54a19
commit 678f64dd27
2 changed files with 2 additions and 1 deletion


@@ -104,6 +104,7 @@ impl TextGeneration {
                 break;
             }
             if let Some(t) = self.tokenizer.next_token(next_token)? {
+                let t = t.replace("<|im_end|>", "\n");
                 print!("{t}");
                 std::io::stdout().flush()?;
             }
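
Note (not part of the diff): the added line strips the chat model's end-of-turn marker before printing so the raw special token does not leak into the streamed console output. A minimal standalone sketch of the effect, assuming the decoded fragment contains the literal `<|im_end|>` string:

    // Standalone sketch: replacing the end-of-turn marker with a newline
    // keeps the raw special token out of the printed stream.
    fn main() {
        let t = "你好！<|im_end|>".to_string();
        let t = t.replace("<|im_end|>", "\n");
        print!("{t}"); // prints "你好！" followed by a newline instead of the marker
    }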


@@ -40,7 +40,7 @@ impl TokenOutputStream {
         };
         self.tokens.push(token);
         let text = self.decode(&self.tokens[self.prev_index..])?;
-        if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
+        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() {
             let text = text.split_at(prev_text.len());
             self.prev_index = self.current_index;
             self.current_index = self.tokens.len();
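
Note (not part of the diff): the second hunk is the actual bilingual fix. `char::is_ascii` is false for every character outside the ASCII range, so the flush condition above never fired once the decoded text ended in a CJK, Cyrillic, or accented character, and non-English output was held back. `char::is_alphabetic` follows the Unicode Alphabetic property and accepts those letters. A standalone sketch comparing the old and new predicates:

    // Standalone sketch: compare the old and new flush predicates on the
    // last decoded character. Only `is_alphabetic` passes non-ASCII letters.
    fn main() {
        for c in ['a', '中', 'é', 'я', '。'] {
            println!(
                "{c}: is_ascii={} is_alphabetic={}",
                c.is_ascii(),
                c.is_alphabetic()
            );
        }
        // a: is_ascii=true is_alphabetic=true
        // 中: is_ascii=false is_alphabetic=true
        // é: is_ascii=false is_alphabetic=true
        // я: is_ascii=false is_alphabetic=true
        // 。: is_ascii=false is_alphabetic=false (punctuation still waits for more tokens)
    }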