mirror of https://github.com/huggingface/candle.git
Get some first inference to work on llama.
```diff
@@ -8,6 +8,9 @@
 //
 // In order to convert the llama weights to a .npz file, run:
 // python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
+
+// TODO: This does not use a batch dimension. If adding it back, be cautious about the
+// transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
```
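The TODO added in this hunk is the key caveat for later refactoring. As a rough illustration (not part of the commit, and assuming candle's `Tensor::zeros`/`transpose` signatures and the in-repo `candle` crate name), dropping the batch dimension shifts which axes a transposition has to touch:

```rust
use candle::{DType, Device, Tensor};

fn main() -> anyhow::Result<()> {
    let dev = Device::Cpu;

    // Without a batch dimension, activations are (seq_len, n_embd),
    // so an attention-style transpose swaps axes 0 and 1.
    let x = Tensor::zeros((5, 8), DType::F32, &dev)?;
    let xt = x.transpose(0, 1)?; // (8, 5)

    // With a batch dimension they become (batch, seq_len, n_embd),
    // and the same logical transpose must target axes 1 and 2 instead.
    let xb = Tensor::zeros((2, 5, 8), DType::F32, &dev)?;
    let xbt = xb.transpose(1, 2)?; // (2, 8, 5)

    println!("{:?} {:?}", xt.shape(), xbt.shape());
    Ok(())
}
```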
```diff
@@ -366,7 +369,9 @@ impl Llama {
         let x = self.ln_f.forward(&x)?;
         let x = x.narrow(0, t - 1, 1)?;
         let logits = self.lm_head.forward(&x)?;
-        Ok(logits)
+        let (b, vocab_size) = logits.shape().r2()?;
+        assert_eq!(b, 1);
+        Ok(logits.reshape(vocab_size)?)
     }
 }
 
```
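After this hunk, `forward` returns a rank-1 logits tensor of shape `(vocab_size,)` instead of `(1, vocab_size)`, so a caller can index the vocabulary directly. A minimal greedy-decoding sketch under that assumption (the `greedy_token` helper is hypothetical, not from this commit; it relies on candle's `Tensor::to_vec1`):

```rust
use candle::Tensor;

// Hypothetical helper: argmax over the 1-D logits that Llama::forward
// now returns, yielding the next token id for greedy decoding.
fn greedy_token(logits: &Tensor) -> anyhow::Result<usize> {
    let v: Vec<f32> = logits.to_vec1()?;
    let (next, _) = v
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .ok_or_else(|| anyhow::anyhow!("empty logits"))?;
    Ok(next)
}
```

The `assert_eq!(b, 1)` in the diff is what makes the final reshape safe: since the preceding `narrow(0, t - 1, 1)` keeps exactly one position, the leading dimension can only ever be 1.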