add eos token to phi example (#965)

* add eos token to phi example

* rustfmt + get the token directly.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Radamés Ajna
2023-09-26 01:21:22 -07:00
committed by GitHub
parent 1fcac4afed
commit 2dd43d6cdd

View File

@@ -66,6 +66,10 @@ impl TextGeneration {
.to_vec(); .to_vec();
let mut new_tokens = vec![]; let mut new_tokens = vec![];
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0..sample_len { for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() }; let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -90,6 +94,9 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?; let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
new_tokens.push(next_token); new_tokens.push(next_token);
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?; let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}"); print!("{token}");
std::io::stdout().flush()?; std::io::stdout().flush()?;