Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Fix the verbose prompt for phi. (#1097)
This commit is contained in:
@@ -59,9 +59,10 @@ impl TextGeneration {
|
|||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
print!("{prompt}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||||
|
if tokens.is_empty() {
|
||||||
|
anyhow::bail!("Empty prompts are not supported in the phi model.")
|
||||||
|
}
|
||||||
if self.verbose_prompt {
|
if self.verbose_prompt {
|
||||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||||
@@ -74,6 +75,8 @@ impl TextGeneration {
|
|||||||
Some(token) => *token,
|
Some(token) => *token,
|
||||||
None => anyhow::bail!("cannot find the endoftext token"),
|
None => anyhow::bail!("cannot find the endoftext token"),
|
||||||
};
|
};
|
||||||
|
print!("{prompt}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
Reference in New Issue
Block a user