Support local weights & dynamic outputs (#447)

* Support local weights & dynamic outputs

* Revise as suggested

* Cargo code format
Author: Guoqing Bao
Date: 2023-08-15 18:51:57 +08:00
Committed by: GitHub
Parent: 531f23b4d0
Commit: 3cc87058b7


@@ -123,6 +123,11 @@ struct Args {
     #[arg(long)]
     use_flash_attn: bool,
+
+    /// The folder name that contains safetensor weights and json files
+    /// (same structure as huggingface online)
+    #[arg(long)]
+    local_weights: Option<String>,
 }
 
 fn main() -> Result<()> {
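
Note: as the next hunk shows, the value of --local-weights is concatenated directly with the file names it should point at, so it must end with a path separator (e.g. --local-weights ./llama-2-7b/). A minimal sketch of the pitfall, with an assumed folder name:

    // Sketch only: `+` on strings inserts no separator.
    let path = String::from("./llama-2-7b"); // missing trailing '/'
    assert_eq!(path + "tokenizer.json", "./llama-2-7btokenizer.json"); // wrong path
    let path = String::from("./llama-2-7b/"); // trailing '/' present
    assert_eq!(path + "tokenizer.json", "./llama-2-7b/tokenizer.json"); // correct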
@@ -165,14 +170,26 @@ fn main() -> Result<()> {
     });
     println!("loading the model weights from {model_id}");
     let api = api.model(model_id);
-    let tokenizer_filename = api.get("tokenizer.json")?;
+
+    let tokenizer_filename = match &args.local_weights {
+        Some(path) => (path.to_owned() + "tokenizer.json").into(),
+        _ => api.get("tokenizer.json")?,
+    };
+
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(rfilename)?;
-        filenames.push(filename);
+        match &args.local_weights {
+            Some(path) => {
+                filenames.push((path.to_owned() + rfilename).into());
+            }
+            _ => {
+                let filename = api.get(rfilename)?;
+                filenames.push(filename);
+            }
+        };
     }
 
     println!("building the model");
@@ -202,8 +219,8 @@ fn main() -> Result<()> {
     let mut new_tokens = vec![];
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
+    let mut token_generated = 0;
     for index in 0..args.sample_len {
-        let start_gen = std::time::Instant::now();
         let context_size = if cache.use_kv_cache && index > 0 {
             1
         } else {
@@ -216,22 +233,29 @@ fn main() -> Result<()> {
         index_pos += ctxt.len();
         let next_token = logits_processor.sample(&logits)?;
+        token_generated += 1;
         tokens.push(next_token);
         new_tokens.push(next_token);
-        println!("> {:?}", start_gen.elapsed());
-        println!(
-            "{} token: {} '{}'",
-            index + 1,
-            next_token,
-            tokenizer.decode(&[next_token], true).map_err(E::msg)?
-        );
+
+        let tk = tokenizer.decode(&[next_token], true).map_err(E::msg)?;
+        if [",", ".", ":", "?", "'", "\""].contains(&tk.as_str())
+            || index == args.sample_len - 1
+            || next_token == 2
+        {
+            //2 for end token
+            print!("{} ", tokenizer.decode(&new_tokens, true).map_err(E::msg)?);
+            new_tokens.clear();
+        }
+
+        if next_token == 2 {
+            break;
+        }
     }
 
     let dt = start_gen.elapsed();
     println!(
-        "{} tokens generated ({} token/s)\n----\n{}\n----",
-        args.sample_len,
-        args.sample_len as f64 / dt.as_secs_f64(),
-        tokenizer.decode(&new_tokens, true).map_err(E::msg)?
+        "\n\n{} tokens generated ({} token/s)\n",
+        token_generated,
+        token_generated as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
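
One caveat with the new incremental output: print! does not flush stdout, so chunks emitted at punctuation boundaries may linger in the buffer rather than stream in real time. A minimal sketch of a flushing helper (an assumption for illustration, not part of this commit):

    use std::io::Write;

    // Hypothetical helper: print a decoded chunk and flush so it appears immediately.
    fn emit_chunk(chunk: &str) -> std::io::Result<()> {
        print!("{chunk} ");
        std::io::stdout().flush()
    }

Note also that the final summary now reports token_generated instead of args.sample_len, which is the accurate count when generation breaks early at the end-of-sequence token (id 2).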