mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
add eos token to phi example (#965)
* add eos token to phi example * rustfmt + get the token directly. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -66,6 +66,10 @@ impl TextGeneration {
|
|||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
let mut new_tokens = vec![];
|
let mut new_tokens = vec![];
|
||||||
|
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||||
|
Some(token) => *token,
|
||||||
|
None => anyhow::bail!("cannot find the endoftext token"),
|
||||||
|
};
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
@ -90,6 +94,9 @@ impl TextGeneration {
|
|||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
new_tokens.push(next_token);
|
new_tokens.push(next_token);
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||||
print!("{token}");
|
print!("{token}");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
|
Reference in New Issue
Block a user