mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Support local weights & dynamic outputs (#447)
* Support local weights & dynamic outputs
* Revise as suggested
* Cargo code format
@@ -123,6 +123,11 @@ struct Args {
 
     #[arg(long)]
     use_flash_attn: bool,
+
+    /// The folder name that contains safetensor weights and json files
+    /// (same structure as huggingface online)
+    #[arg(long)]
+    local_weights: Option<String>,
 }
 
 fn main() -> Result<()> {
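With clap's derive API, this new `local_weights` field surfaces as a `--local-weights` command-line flag, and because it is an `Option<String>` the example keeps downloading from the Hugging Face Hub whenever the flag is omitted. A minimal standalone sketch of how such an optional flag parses (assuming a `clap` dependency with the `derive` feature; the real `Args` struct has many more fields):

    use clap::Parser;

    /// Sketch of an optional CLI flag parsed with clap's derive API.
    #[derive(Parser, Debug)]
    struct Args {
        /// Folder containing safetensor weights and json files.
        #[arg(long)]
        local_weights: Option<String>,
    }

    fn main() {
        let args = Args::parse();
        // Prints `None` when --local-weights is not passed.
        println!("{:?}", args.local_weights);
    }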
@@ -165,14 +170,26 @@ fn main() -> Result<()> {
     });
     println!("loading the model weights from {model_id}");
     let api = api.model(model_id);
-    let tokenizer_filename = api.get("tokenizer.json")?;
+
+    let tokenizer_filename = match &args.local_weights {
+        Some(path) => (path.to_owned() + "tokenizer.json").into(),
+        _ => api.get("tokenizer.json")?,
+    };
+
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(rfilename)?;
-        filenames.push(filename);
+        match &args.local_weights {
+            Some(path) => {
+                filenames.push((path.to_owned() + rfilename).into());
+            }
+            _ => {
+                let filename = api.get(rfilename)?;
+                filenames.push(filename);
+            }
+        };
     }
 
     println!("building the model");
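Note that the local paths are built by plain string concatenation (`path.to_owned() + rfilename`), so the value passed to `--local-weights` must end with a path separator for the file names to resolve. A separator-agnostic alternative, sketched here with `std::path::PathBuf` (not what the commit does, just an option):

    use std::path::PathBuf;

    // Join the folder and file name instead of concatenating strings,
    // so a trailing separator on the folder is no longer required.
    fn local_file(folder: &str, rfilename: &str) -> PathBuf {
        PathBuf::from(folder).join(rfilename)
    }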
@@ -202,8 +219,8 @@ fn main() -> Result<()> {
     let mut new_tokens = vec![];
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
+    let mut token_generated = 0;
     for index in 0..args.sample_len {
-        let start_gen = std::time::Instant::now();
         let context_size = if cache.use_kv_cache && index > 0 {
             1
         } else {
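Hoisting `start_gen` out of the loop and counting sampled tokens in `token_generated` makes the final tokens-per-second figure cover the whole generation run, and keeps it accurate when the loop exits early at the end-of-sequence token. A self-contained sketch of the timing pattern (the sleep stands in for a model forward pass):

    use std::{thread, time::{Duration, Instant}};

    fn main() {
        let start_gen = Instant::now();
        let mut token_generated = 0;
        for _ in 0..10 {
            thread::sleep(Duration::from_millis(5)); // stand-in for sampling
            token_generated += 1; // one increment per sampled token
        }
        let dt = start_gen.elapsed();
        // Throughput over the whole run, not per-token timings.
        println!("{} token/s", token_generated as f64 / dt.as_secs_f64());
    }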
@@ -216,22 +233,29 @@ fn main() -> Result<()> {
         index_pos += ctxt.len();
 
         let next_token = logits_processor.sample(&logits)?;
+        token_generated += 1;
         tokens.push(next_token);
         new_tokens.push(next_token);
-        println!("> {:?}", start_gen.elapsed());
-        println!(
-            "{} token: {} '{}'",
-            index + 1,
-            next_token,
-            tokenizer.decode(&[next_token], true).map_err(E::msg)?
-        );
+
+        let tk = tokenizer.decode(&[next_token], true).map_err(E::msg)?;
+        if [",", ".", ":", "?", "'", "\""].contains(&tk.as_str())
+            || index == args.sample_len - 1
+            || next_token == 2
+        {
+            // 2 is the end-of-sequence token.
+            print!("{} ", tokenizer.decode(&new_tokens, true).map_err(E::msg)?);
+            new_tokens.clear();
+        }
+
+        if next_token == 2 {
+            break;
+        }
     }
     let dt = start_gen.elapsed();
     println!(
-        "{} tokens generated ({} token/s)\n----\n{}\n----",
-        args.sample_len,
-        args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(&new_tokens, true).map_err(E::msg)?
+        "\n\n{} tokens generated ({} token/s)\n",
+        token_generated,
+        token_generated as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
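Taken together, the "dynamic outputs" part of the change buffers sampled tokens in `new_tokens` and only decodes and flushes them at punctuation, at the final step, or at the end-of-sequence token (id 2), so multi-token words print as a whole rather than fragment by fragment. Assuming this diff lands in candle's llama example, a local-weights invocation might look like the following (the trailing slash matters because of the string concatenation above):

    cargo run --example llama --release -- --local-weights /path/to/llama-weights/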