mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a prompt and support more models in llama2-c. (#285)
* Support more models in llama2-c. * Add a prompt.
This commit is contained in:
@ -193,6 +193,15 @@ struct Args {
|
||||
|
||||
#[arg(long, default_value = "karpathy/tinyllamas")]
|
||||
model_id: String,
|
||||
|
||||
/// The model file to fetch from the hub repository. Possible
|
||||
/// values are 'stories15M.bin' and 'stories42M.bin'; see more at:
|
||||
/// https://huggingface.co/karpathy/tinyllamas/tree/main
|
||||
#[arg(long, default_value = "stories15M.bin")]
|
||||
which_model: String,
|
||||
|
||||
#[arg(long, default_value = "")]
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
@ -206,7 +215,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
println!("loading the model weights from {}", args.model_id);
|
||||
let api = api.model(args.model_id);
|
||||
api.get("stories15M.bin")?
|
||||
api.get(&args.which_model)?
|
||||
}
|
||||
};
|
||||
let mut file = std::fs::File::open(&config_path)?;
|
||||
@ -226,15 +235,24 @@ fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
};
|
||||
println!("{tokenizer_path:?}");
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
|
||||
|
||||
println!("starting the inference loop");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
|
||||
let mut index_pos = 0;
|
||||
let mut tokens = vec![1u32];
|
||||
|
||||
print!("{}", args.prompt);
|
||||
let mut tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..config.seq_len - 10 {
|
||||
for index in 0.. {
|
||||
if tokens.len() >= config.seq_len {
|
||||
break;
|
||||
}
|
||||
let start_gen = std::time::Instant::now();
|
||||
let context_size = if cache.use_kv_cache && index > 0 {
|
||||
1
|
||||
|
Reference in New Issue
Block a user