Rework the embeddings so that they work on non-contiguous weights + factor out some code.

laurent
2023-06-25 17:37:47 +01:00
parent 334524e2c4
commit 817e4b5005
6 changed files with 66 additions and 48 deletions


@@ -350,7 +350,8 @@ impl Llama {
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (_, t) = x.shape().r2()?;
+        // TODO: Support for mini-batches? (i.e. r2)
+        let t = x.shape().r1()?;
         let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, freqs_cis)?;
@@ -427,7 +428,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
+        let input = Tensor::new(ctxt, &Device::Cpu)?;
        let logits = llama.forward(&input, &freqs_cis)?;
        let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
        let logits_v: Vec<f32> = prs.to_vec1()?;
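
For context on what "works on non-contiguous weights" means here: an embedding lookup only needs to gather rows of the weight matrix, so it can be written against explicit strides rather than assuming a densely packed layout. The sketch below is a deliberately simplified, stand-alone illustration of that idea, not the candle implementation; the `StridedWeights` struct and `embedding_forward` function are hypothetical names introduced just for this example.

// Hypothetical, simplified illustration (not the candle code): an embedding
// lookup that reads the weight matrix through explicit strides, so it keeps
// working even when the weights are a non-contiguous view.
struct StridedWeights {
    data: Vec<f32>,     // underlying storage
    offset: usize,      // start of the view inside `data`
    row_stride: usize,  // elements to skip to reach the next row
    col_stride: usize,  // elements to skip to reach the next column
    hidden_size: usize, // number of columns per embedding row
}

impl StridedWeights {
    // Copy the row for one token id into a freshly allocated, contiguous buffer.
    fn row(&self, token_id: usize) -> Vec<f32> {
        let start = self.offset + token_id * self.row_stride;
        (0..self.hidden_size)
            .map(|c| self.data[start + c * self.col_stride])
            .collect()
    }
}

// Gather the embedding rows for a rank-1 slice of token ids; the output is
// always contiguous regardless of the layout of the source weights.
fn embedding_forward(weights: &StridedWeights, token_ids: &[usize]) -> Vec<f32> {
    let mut out = Vec::with_capacity(token_ids.len() * weights.hidden_size);
    for &id in token_ids {
        out.extend(weights.row(id));
    }
    out
}

fn main() {
    // 3-token vocabulary, hidden size 2, stored contiguously here, but the
    // strides could just as well describe a transposed or sliced view.
    let weights = StridedWeights {
        data: vec![0.0, 0.1, 1.0, 1.1, 2.0, 2.1],
        offset: 0,
        row_stride: 2,
        col_stride: 1,
        hidden_size: 2,
    };
    let embedded = embedding_forward(&weights, &[2, 0, 1]);
    println!("{embedded:?}"); // [2.0, 2.1, 0.0, 0.1, 1.0, 1.1]
}

Because each selected row is copied element by element through `row_stride` and `col_stride`, the same lookup works whether the weights come from a contiguous allocation or from a transposed or narrowed view, which is the property the commit message refers to.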