Use the tokenizer-output-stream in the llama example. (#1715)

* Use the tokenizer-output-stream in the llama example.

* Also use tokenizer-output-stream for llama2-c.
Laurent Mazare, 2024-02-15 16:47:33 +01:00, committed by GitHub
parent 058a910d0e
commit 7c7400fb63
4 changed files with 17 additions and 20 deletions
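Why this helps: the heuristic removed below turned each sampled id into text with id_to_token plus manual replacements of '▁' and "<0x0A>", which garbles output when a byte-fallback token splits a multi-byte UTF-8 character across several ids. A token-output-stream instead buffers the sampled ids, re-decodes through the tokenizer's own decoder, and only emits text once it is complete. Below is a minimal sketch of that idea (an illustration assuming the tokenizers and anyhow crates, not the actual candle-examples implementation; StreamDecoder is an illustrative name):

use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

/// Illustrative streaming decoder: buffers sampled ids and only hands back
/// text that is safe to print.
struct StreamDecoder {
    tokenizer: Tokenizer,
    tokens: Vec<u32>, // every token sampled so far
    emitted: usize,   // byte length of the decoded text already printed
}

impl StreamDecoder {
    fn new(tokenizer: Tokenizer) -> Self {
        Self { tokenizer, tokens: vec![], emitted: 0 }
    }

    /// Push one sampled id; return any newly completed text.
    fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        self.tokens.push(token);
        // Re-decode the whole buffer: the tokenizer's decoder handles word-piece
        // markers and byte-fallback tokens, so no manual replacements are needed.
        let text = self.tokenizer.decode(&self.tokens, true).map_err(E::msg)?;
        // An unfinished byte sequence decodes to a trailing U+FFFD replacement
        // character; hold the suffix back until the character completes.
        if text.ends_with('\u{FFFD}') {
            return Ok(None);
        }
        match text.get(self.emitted..) {
            Some(new) if !new.is_empty() => {
                self.emitted = text.len();
                Ok(Some(new.to_string()))
            }
            _ => Ok(None),
        }
    }

    /// Flush whatever is still held back once generation stops.
    fn decode_rest(&self) -> Result<Option<String>> {
        let text = self.tokenizer.decode(&self.tokens, true).map_err(E::msg)?;
        Ok(text.get(self.emitted..).filter(|s| !s.is_empty()).map(str::to_string))
    }
}

The real TokenOutputStream keeps indices into the buffer so each step only re-decodes a small window rather than the full sequence, but the contract matches what the diff below relies on: next_token inside the generation loop, decode_rest after it.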

View File

@@ -57,7 +57,7 @@ struct Args {
     seed: u64,
     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 100)]
+    #[arg(long, default_value_t = 10000)]
     sample_len: usize,
     /// Disable the key-value cache.
@ -143,7 +143,6 @@ fn main() -> Result<()> {
} }
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
}; };
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
@@ -157,6 +156,7 @@ fn main() -> Result<()> {
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
     println!("starting the inference loop");
     print!("{prompt}");
@@ -190,18 +190,16 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);
 
-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
-            std::io::stdout().flush()?;
-        }
         if Some(next_token) == eos_token_id {
             break;
         }
+        if let Some(t) = tokenizer.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+    }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
     }
     let dt = start_gen.elapsed();
     println!(
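Condensed, the pattern the hunks above adopt is three calls: wrap the tokenizer once, stream each sampled id, flush what is left at the end. A sketch, assuming candle-examples and tokenizers as dependencies; sample_next is a hypothetical stand-in for the model forward pass and sampling, which this commit leaves untouched:

use std::io::Write;
use anyhow::{Error as E, Result};
use candle_examples::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;

fn stream_generation(
    tokenizer: Tokenizer,
    eos_token_id: Option<u32>,
    sample_len: usize,
    mut sample_next: impl FnMut() -> Result<u32>, // hypothetical: forward pass + sampling
) -> Result<()> {
    let mut tokenizer = TokenOutputStream::new(tokenizer);
    for _ in 0..sample_len {
        let next_token = sample_next()?;
        if Some(next_token) == eos_token_id {
            break;
        }
        // Returns Some(text) only once the buffered ids decode cleanly.
        if let Some(t) = tokenizer.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    // Emit anything still held back, e.g. a pending partial word.
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    Ok(())
}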

View File

@@ -328,6 +328,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
     let start_gen = std::time::Instant::now();
     for index in 0.. {
@@ -353,16 +354,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
 
-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
+        if let Some(t) = tokenizer.next_token(next_token)? {
+            print!("{t}");
             std::io::stdout().flush()?;
         }
     }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
     let dt = start_gen.elapsed();
     println!(
         "\n{} tokens generated ({:.2} token/s)\n",

View File

@@ -152,7 +152,7 @@ struct Args {
     seed: u64,
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 10000)]
     sample_len: usize,
     #[arg(long)]

View File

@@ -143,7 +143,7 @@ struct Args {
     seed: u64,
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 10000)]
     sample_len: usize,
    #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]