Get a some first inference to work on llama.

This commit is contained in:
laurent
2023-06-25 18:26:15 +01:00
parent 87c5aab005
commit 8e404eb125

View File

@ -8,6 +8,9 @@
//
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
// TODO: This does not use a batch dimension. If adding it back, be cautious about the
// transposition operations.
use anyhow::{Error as E, Result};
use clap::Parser;
@ -366,7 +369,9 @@ impl Llama {
let x = self.ln_f.forward(&x)?;
let x = x.narrow(0, t - 1, 1)?;
let logits = self.lm_head.forward(&x)?;
Ok(logits)
let (b, vocab_size) = logits.shape().r2()?;
assert_eq!(b, 1);
Ok(logits.reshape(vocab_size)?)
}
}