From 8e404eb125dee5d7a2a6a6405bc2606e463ba4d6 Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 25 Jun 2023 18:26:15 +0100
Subject: [PATCH] Get some first inference to work on llama.

---
 examples/llama/main.rs | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/llama/main.rs b/examples/llama/main.rs
index c9089e95..e09f5e2f 100644
--- a/examples/llama/main.rs
+++ b/examples/llama/main.rs
@@ -8,6 +8,9 @@
 //
 // In order to convert the llama weights to a .npz file, run:
 // python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
+
+// TODO: This does not use a batch dimension. If adding it back, be cautious about the
+// transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
@@ -366,7 +369,9 @@ impl Llama {
         let x = self.ln_f.forward(&x)?;
         let x = x.narrow(0, t - 1, 1)?;
         let logits = self.lm_head.forward(&x)?;
-        Ok(logits)
+        let (b, vocab_size) = logits.shape().r2()?;
+        assert_eq!(b, 1);
+        Ok(logits.reshape(vocab_size)?)
     }
 }
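Note on the new tail of forward: x.narrow(0, t - 1, 1) keeps only the last
position, so lm_head produces logits of shape (1, vocab_size), and the
assert_eq! plus reshape squeeze that singleton leading dimension so callers get
a flat vector. A minimal sketch in plain Rust (illustrative values, no candle
types) of why dropping a size-1 leading dimension is a pure relabeling of the
same row-major data:

// Minimal sketch, plain Rust only: flattening a (1, vocab_size) matrix yields
// the same vocab_size elements in the same order, which is what the
// assert_eq!(b, 1) + reshape(vocab_size) in the patch rely on.
fn main() {
    let vocab_size = 4;
    // Hypothetical logits for the single kept position (illustrative values).
    let logits_2d: Vec<Vec<f32>> = vec![vec![0.1, 0.7, 0.05, 0.15]];
    // Mirrors the patch's assert: only a singleton leading dim may be dropped.
    assert_eq!(logits_2d.len(), 1);
    // Flattening does not reorder any element; it just removes the nesting.
    let logits_1d: Vec<f32> = logits_2d.into_iter().flatten().collect();
    assert_eq!(logits_1d.len(), vocab_size);
    println!("{logits_1d:?}");
}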