diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index b1e112fd..9cb9d91d 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -123,6 +123,11 @@ struct Args { #[arg(long)] use_flash_attn: bool, + + /// The folder name that contains safetensor weights and json files + /// (same structure as huggingface online) + #[arg(long)] + local_weights: Option, } fn main() -> Result<()> { @@ -165,14 +170,26 @@ fn main() -> Result<()> { }); println!("loading the model weights from {model_id}"); let api = api.model(model_id); - let tokenizer_filename = api.get("tokenizer.json")?; + + let tokenizer_filename = match &args.local_weights { + Some(path) => (path.to_owned() + "tokenizer.json").into(), + _ => api.get("tokenizer.json")?, + }; + let mut filenames = vec![]; for rfilename in [ "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors", ] { - let filename = api.get(rfilename)?; - filenames.push(filename); + match &args.local_weights { + Some(path) => { + filenames.push((path.to_owned() + rfilename).into()); + } + _ => { + let filename = api.get(rfilename)?; + filenames.push(filename); + } + }; } println!("building the model"); @@ -202,8 +219,8 @@ fn main() -> Result<()> { let mut new_tokens = vec![]; let start_gen = std::time::Instant::now(); let mut index_pos = 0; + let mut token_generated = 0; for index in 0..args.sample_len { - let start_gen = std::time::Instant::now(); let context_size = if cache.use_kv_cache && index > 0 { 1 } else { @@ -216,22 +233,29 @@ fn main() -> Result<()> { index_pos += ctxt.len(); let next_token = logits_processor.sample(&logits)?; + token_generated += 1; tokens.push(next_token); new_tokens.push(next_token); - println!("> {:?}", start_gen.elapsed()); - println!( - "{} token: {} '{}'", - index + 1, - next_token, - tokenizer.decode(&[next_token], true).map_err(E::msg)? - ); + + let tk = tokenizer.decode(&[next_token], true).map_err(E::msg)?; + if [",", ".", ":", "?", "'", "\""].contains(&tk.as_str()) + || index == args.sample_len - 1 + || next_token == 2 + { + //2 for end token + print!("{} ", tokenizer.decode(&new_tokens, true).map_err(E::msg)?); + new_tokens.clear(); + } + + if next_token == 2 { + break; + } } let dt = start_gen.elapsed(); println!( - "{} tokens generated ({} token/s)\n----\n{}\n----", - args.sample_len, - args.sample_len as f64 / dt.as_secs_f64(), - tokenizer.decode(&new_tokens, true).map_err(E::msg)? + "\n\n{} tokens generated ({} token/s)\n", + token_generated, + token_generated as f64 / dt.as_secs_f64(), ); Ok(()) }