From 2dd43d6cdd3242bcbe49a0558e56e24549a866d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radam=C3=A9s=20Ajna?= Date: Tue, 26 Sep 2023 01:21:22 -0700 Subject: [PATCH] add eos token to phi example (#965) * add eos token to phi example * rustfmt + get the token directly. --------- Co-authored-by: laurent --- candle-examples/examples/phi/main.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index fe365e18..ab37ed5f 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -66,6 +66,10 @@ impl TextGeneration { .to_vec(); let mut new_tokens = vec![]; + let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { + Some(token) => *token, + None => anyhow::bail!("cannot find the endoftext token"), + }; let start_gen = std::time::Instant::now(); for index in 0..sample_len { let context_size = if index > 0 { 1 } else { tokens.len() }; @@ -90,6 +94,9 @@ impl TextGeneration { let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); new_tokens.push(next_token); + if next_token == eos_token { + break; + } let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?; print!("{token}"); std::io::stdout().flush()?;