Improve the reshape error messages. (#1096)

* Improve the reshape error messages.

* Add the verbose-prompt flag to the phi example.
This commit is contained in:
Laurent Mazare
2023-10-15 10:43:10 +01:00
committed by GitHub
parent 8f310cc666
commit b73c35cc57
2 changed files with 49 additions and 75 deletions

View File

@ -28,6 +28,7 @@ struct TextGeneration {
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
@ -40,6 +41,7 @@ impl TextGeneration {
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
@ -49,6 +51,7 @@ impl TextGeneration {
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
@ -58,13 +61,14 @@ impl TextGeneration {
println!("starting the inference loop");
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
@ -129,6 +133,10 @@ struct Args {
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
@ -266,6 +274,7 @@ fn main() -> Result<()> {
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;