Add a prompt and support more models in llama2-c. (#285)

* Support more models in llama2-c.

* Add a prompt.
This commit is contained in:
Laurent Mazare
2023-07-31 13:09:30 +01:00
committed by GitHub
parent 94a43faaca
commit b3ea96b62b
2 changed files with 26 additions and 6 deletions

View File

@@ -193,6 +193,15 @@ struct Args {
#[arg(long, default_value = "karpathy/tinyllamas")]
model_id: String,
+    /// The model to be used when getting it from the hub. Possible
+    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
+    /// https://huggingface.co/karpathy/tinyllamas/tree/main
+    #[arg(long, default_value = "stories15M.bin")]
+    which_model: String,
+    #[arg(long, default_value = "")]
+    prompt: String,
}
fn main() -> anyhow::Result<()> {
@@ -206,7 +215,7 @@ fn main() -> anyhow::Result<()> {
let api = hf_hub::api::sync::Api::new()?;
println!("loading the model weights from {}", args.model_id);
let api = api.model(args.model_id);
-            api.get("stories15M.bin")?
+            api.get(&args.which_model)?
}
};
let mut file = std::fs::File::open(&config_path)?;
@@ -226,15 +235,24 @@ fn main() -> anyhow::Result<()> {
}
};
println!("{tokenizer_path:?}");
-    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
+    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
let mut index_pos = 0;
-    let mut tokens = vec![1u32];
+    print!("{}", args.prompt);
+    let mut tokens = tokenizer
+        .encode(args.prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
let start_gen = std::time::Instant::now();
-    for index in 0..config.seq_len - 10 {
+    for index in 0.. {
+        if tokens.len() >= config.seq_len {
+            break;
+        }
+        let start_gen = std::time::Instant::now();
let context_size = if cache.use_kv_cache && index > 0 {
1

View File

@@ -112,8 +112,10 @@ struct CausalSelfAttention {
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
+        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;