From 678f64dd275d1f34b0d46f9fb3dbbd00337d1426 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 6 Feb 2024 19:03:53 +0800 Subject: [PATCH] Fix token generation in bilingual models (non-English outputs) (#1668) Co-authored-by: Guoqing Bao --- candle-examples/examples/yi/main.rs | 1 + candle-examples/src/token_output_stream.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/candle-examples/examples/yi/main.rs b/candle-examples/examples/yi/main.rs index e4cbfc6f..2c51586a 100644 --- a/candle-examples/examples/yi/main.rs +++ b/candle-examples/examples/yi/main.rs @@ -104,6 +104,7 @@ impl TextGeneration { break; } if let Some(t) = self.tokenizer.next_token(next_token)? { + let t = t.replace("<|im_end|>", "\n"); print!("{t}"); std::io::stdout().flush()?; } diff --git a/candle-examples/src/token_output_stream.rs b/candle-examples/src/token_output_stream.rs index 907d8ddd..07f33620 100644 --- a/candle-examples/src/token_output_stream.rs +++ b/candle-examples/src/token_output_stream.rs @@ -40,7 +40,7 @@ impl TokenOutputStream { }; self.tokens.push(token); let text = self.decode(&self.tokens[self.prev_index..])?; - if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() { + if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() { let text = text.split_at(prev_text.len()); self.prev_index = self.current_index; self.current_index = self.tokens.len();