Use the tokenizer-output-stream in the llama example. (#1715)
* Use the tokenizer-output-stream in the llama example.
* Also use tokenizer-output-stream for llama2-c.
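Concretely, both examples previously converted each sampled token id to text with `tokenizer.id_to_token` plus ad-hoc `▁` / `<0x0A>` replacements; after this change they feed the ids through `candle_examples::token_output_stream::TokenOutputStream`, which buffers the ids and only emits text once it decodes cleanly. A rough sketch of the pattern, with a hypothetical `stream_tokens` wrapper around the calls that appear in the diff below:

// A minimal sketch (not the full example) of the streaming pattern adopted below.
// The `stream_tokens` wrapper and its parameters are hypothetical; the
// `TokenOutputStream` calls (`new`, `next_token`, `decode_rest`) are the ones
// visible in the diff.
use std::io::Write;

use anyhow::Error as E;
use candle_examples::token_output_stream::TokenOutputStream;

fn stream_tokens(
    tokenizer: tokenizers::Tokenizer,
    token_ids: impl IntoIterator<Item = u32>,
    eos_token_id: Option<u32>,
) -> anyhow::Result<()> {
    // Wrap the plain tokenizer; the stream tracks what has already been
    // printed so each new token only emits the newly decodable text.
    let mut tokenizer = TokenOutputStream::new(tokenizer);
    for next_token in token_ids {
        if Some(next_token) == eos_token_id {
            break;
        }
        // `next_token` returns Some(text) once the new token yields printable text.
        if let Some(t) = tokenizer.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    // Flush whatever is still buffered after the loop (e.g. a trailing token
    // that did not form a complete piece of text on its own).
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    Ok(())
}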
@@ -57,7 +57,7 @@ struct Args {
     seed: u64,

     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 100)]
+    #[arg(long, default_value_t = 10000)]
     sample_len: usize,

     /// Disable the key-value cache.
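Note that the default sample length here (and in the other examples changed below) goes from 100 to 10,000 tokens, so a run now generates until it hits EOS or that much larger budget unless a smaller value is passed via the clap-derived `--sample-len` flag.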
@@ -143,7 +143,6 @@ fn main() -> Result<()> {
         }
         Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
     };
-    println!("building the model");
     let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
@@ -157,6 +156,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

     println!("starting the inference loop");
     print!("{prompt}");
@@ -190,18 +190,16 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);

-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
-            std::io::stdout().flush()?;
-        }
         if Some(next_token) == eos_token_id {
             break;
         }
+        if let Some(t) = tokenizer.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+    }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
     }
     let dt = start_gen.elapsed();
     println!(
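Two details of the rewritten llama loop: the EOS check now happens before anything is printed, so the end-of-sequence token is no longer echoed, and the `decode_rest()` call after the loop prints whatever text the stream still has buffered when generation stops.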
@@ -328,6 +328,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

     let start_gen = std::time::Instant::now();
     for index in 0.. {
@@ -353,16 +354,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
+        if let Some(t) = tokenizer.next_token(next_token)? {
+            print!("{t}");
             std::io::stdout().flush()?;
         }
     }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
     let dt = start_gen.elapsed();
     println!(
         "\n{} tokens generated ({:.2} token/s)\n",
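The llama2-c `run_inference` loop follows the same pattern: per-step text comes from `next_token`, and `decode_rest` flushes any remaining buffered text before the tokens/s summary is printed.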
@@ -152,7 +152,7 @@ struct Args {
     seed: u64,

     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 10000)]
     sample_len: usize,

     #[arg(long)]
@@ -143,7 +143,7 @@ struct Args {
     seed: u64,

     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 10000)]
     sample_len: usize,

     #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]