Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Support local weights & dynamic outputs (#447)
* Support local weights & dynamic outputs
* Revise as suggested
* Cargo code format
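Usage note (the invocation below is an illustration and an assumption, not part of this commit): the new local_weights field is declared with clap's #[arg(long)], so it is exposed on the command line as --local-weights and points at a folder laid out like the Hugging Face repository, e.g.

cargo run --example llama --release -- --local-weights /path/to/llama-weights/

where the folder is assumed to contain tokenizer.json and the two model-0000?-of-00002.safetensors shards referenced below.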
@@ -123,6 +123,11 @@ struct Args {
 
     #[arg(long)]
     use_flash_attn: bool,
+
+    /// The folder name that contains safetensor weights and json files
+    /// (same structure as huggingface online)
+    #[arg(long)]
+    local_weights: Option<String>,
 }
 
 fn main() -> Result<()> {
@@ -165,14 +170,26 @@ fn main() -> Result<()> {
     });
     println!("loading the model weights from {model_id}");
     let api = api.model(model_id);
-    let tokenizer_filename = api.get("tokenizer.json")?;
+
+    let tokenizer_filename = match &args.local_weights {
+        Some(path) => (path.to_owned() + "tokenizer.json").into(),
+        _ => api.get("tokenizer.json")?,
+    };
+
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(rfilename)?;
-        filenames.push(filename);
+        match &args.local_weights {
+            Some(path) => {
+                filenames.push((path.to_owned() + rfilename).into());
+            }
+            _ => {
+                let filename = api.get(rfilename)?;
+                filenames.push(filename);
+            }
+        };
     }
 
     println!("building the model");
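The lookup above goes to the Hugging Face Hub only when no local folder is given. A minimal, self-contained sketch of the same pattern (fetch_from_hub is a stub standing in for the real api.get(rfilename)? call) also makes one caveat explicit: the file names are appended by plain string concatenation rather than path joining, so the --local-weights value presumably needs a trailing separator.

use std::path::PathBuf;

// Stub standing in for the hf-hub download; the real code returns the path
// of the file fetched from (or cached by) the Hub.
fn fetch_from_hub(rfilename: &str) -> PathBuf {
    PathBuf::from(format!("/cache/{rfilename}"))
}

fn resolve(local_weights: &Option<String>, rfilename: &str) -> PathBuf {
    match local_weights {
        // "weights/" + "tokenizer.json" -> "weights/tokenizer.json", but
        // "weights" + "tokenizer.json" -> "weightstokenizer.json": the
        // separator is not inserted automatically.
        Some(path) => (path.to_owned() + rfilename).into(),
        _ => fetch_from_hub(rfilename),
    }
}

fn main() {
    let local = Some("weights/".to_string());
    for f in ["tokenizer.json", "model-00001-of-00002.safetensors"] {
        println!("{:?}", resolve(&local, f));
    }
    println!("{:?}", resolve(&None, "tokenizer.json"));
}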
@@ -202,8 +219,8 @@ fn main() -> Result<()> {
     let mut new_tokens = vec![];
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
+    let mut token_generated = 0;
     for index in 0..args.sample_len {
-        let start_gen = std::time::Instant::now();
         let context_size = if cache.use_kv_cache && index > 0 {
             1
         } else {
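With the per-iteration timer gone, throughput is measured once over the whole generation loop: a single Instant started before the loop plus a counter of tokens actually produced, so the reported token/s stays meaningful when generation stops early at the end-of-sequence token. A minimal sketch of that accounting (the sleep merely stands in for a forward pass):

use std::{thread, time::{Duration, Instant}};

fn main() {
    let sample_len = 100;
    let start_gen = Instant::now();
    let mut token_generated = 0u32;
    for _ in 0..sample_len {
        // Stand-in for sampling one token from the model.
        thread::sleep(Duration::from_millis(5));
        token_generated += 1;
        if token_generated == 42 {
            break; // e.g. the end-of-sequence token was sampled
        }
    }
    let dt = start_gen.elapsed();
    println!(
        "{} tokens generated ({:.2} token/s)",
        token_generated,
        token_generated as f64 / dt.as_secs_f64()
    );
}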
@@ -216,22 +233,29 @@ fn main() -> Result<()> {
         index_pos += ctxt.len();
 
         let next_token = logits_processor.sample(&logits)?;
+        token_generated += 1;
         tokens.push(next_token);
         new_tokens.push(next_token);
-        println!("> {:?}", start_gen.elapsed());
-        println!(
-            "{} token: {} '{}'",
-            index + 1,
-            next_token,
-            tokenizer.decode(&[next_token], true).map_err(E::msg)?
-        );
+
+        let tk = tokenizer.decode(&[next_token], true).map_err(E::msg)?;
+        if [",", ".", ":", "?", "'", "\""].contains(&tk.as_str())
+            || index == args.sample_len - 1
+            || next_token == 2
+        {
+            //2 for end token
+            print!("{} ", tokenizer.decode(&new_tokens, true).map_err(E::msg)?);
+            new_tokens.clear();
+        }
+
+        if next_token == 2 {
+            break;
+        }
     }
     let dt = start_gen.elapsed();
     println!(
-        "{} tokens generated ({} token/s)\n----\n{}\n----",
-        args.sample_len,
-        args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(&new_tokens, true).map_err(E::msg)?
+        "\n\n{} tokens generated ({} token/s)\n",
+        token_generated,
+        token_generated as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
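The "dynamic outputs" part above buffers sampled ids in new_tokens and only flushes them to stdout when the latest piece decodes to punctuation, when the sampling budget runs out, or when the end-of-sequence token (id 2 for LLaMA) appears, so partial sub-word pieces are not printed on their own. A minimal sketch of that flushing heuristic, with a toy decode function standing in for the tokenizers crate:

// LLaMA's end-of-sequence id, as hard-coded in the diff above.
const EOS_TOKEN: u32 = 2;

// Toy decoder: the real example calls tokenizer.decode(..); here id 13
// decodes to "." just so something triggers a flush mid-sequence.
fn decode(tokens: &[u32]) -> String {
    tokens
        .iter()
        .map(|t| if *t == 13 { ".".to_string() } else { format!("<{t}>") })
        .collect()
}

fn main() {
    let sample_len = 8;
    let sampled = [17u32, 42, 13, 7, 5, EOS_TOKEN];
    let mut new_tokens: Vec<u32> = vec![];
    for (index, &next_token) in sampled.iter().enumerate() {
        new_tokens.push(next_token);
        let tk = decode(&[next_token]);
        if [",", ".", ":", "?", "'", "\""].contains(&tk.as_str())
            || index == sample_len - 1
            || next_token == EOS_TOKEN
        {
            // Flush everything buffered since the last print.
            print!("{} ", decode(&new_tokens));
            new_tokens.clear();
        }
        if next_token == EOS_TOKEN {
            break;
        }
    }
    println!();
}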