mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Get a some first inference to work on llama.
This commit is contained in:
@ -8,6 +8,9 @@
|
||||
//
|
||||
// In order to convert the llama weights to a .npz file, run:
|
||||
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
|
||||
|
||||
// TODO: This does not use a batch dimension. If adding it back, be cautious about the
|
||||
// transposition operations.
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
@ -366,7 +369,9 @@ impl Llama {
|
||||
let x = self.ln_f.forward(&x)?;
|
||||
let x = x.narrow(0, t - 1, 1)?;
|
||||
let logits = self.lm_head.forward(&x)?;
|
||||
Ok(logits)
|
||||
let (b, vocab_size) = logits.shape().r2()?;
|
||||
assert_eq!(b, 1);
|
||||
Ok(logits.reshape(vocab_size)?)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user