diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 20a6267c..dca85ead 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -193,6 +193,15 @@ struct Args {
 
     #[arg(long, default_value = "karpathy/tinyllamas")]
     model_id: String,
+
+    /// The model to be used when getting it from the hub. Possible
+    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
+    /// https://huggingface.co/karpathy/tinyllamas/tree/main
+    #[arg(long, default_value = "stories15M.bin")]
+    which_model: String,
+
+    #[arg(long, default_value = "")]
+    prompt: String,
 }
 
 fn main() -> anyhow::Result<()> {
@@ -206,7 +215,7 @@ fn main() -> anyhow::Result<()> {
             let api = hf_hub::api::sync::Api::new()?;
             println!("loading the model weights from {}", args.model_id);
             let api = api.model(args.model_id);
-            api.get("stories15M.bin")?
+            api.get(&args.which_model)?
         }
     };
     let mut file = std::fs::File::open(&config_path)?;
@@ -226,15 +235,24 @@ fn main() -> anyhow::Result<()> {
         }
     };
     println!("{tokenizer_path:?}");
-    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
 
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
     let mut index_pos = 0;
-    let mut tokens = vec![1u32];
+
+    print!("{}", args.prompt);
+    let mut tokens = tokenizer
+        .encode(args.prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
 
     let start_gen = std::time::Instant::now();
-    for index in 0..config.seq_len - 10 {
+    for index in 0.. {
+        if tokens.len() >= config.seq_len {
+            break;
+        }
         let start_gen = std::time::Instant::now();
         let context_size = if cache.use_kv_cache && index > 0 {
             1
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 9e1c3eda..fbeb4038 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -112,8 +112,10 @@ struct CausalSelfAttention {
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
+        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
+        let cos = cos.unsqueeze(1)?;
+        let sin = sin.unsqueeze(1)?;
         let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;