Mirror of https://github.com/huggingface/candle.git
Rework the embeddings so that they work on non-contiguous weights + factor out some code.
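The commit title is terse, so here is a minimal sketch of the kind of change it describes: making an embedding lookup tolerate a weight tensor that is only a non-contiguous view. This is not the code touched by the commit; the Embedding wrapper, the candle_core crate path, and the shapes are assumptions, and it relies on candle-style contiguous() and index_select() tensor methods.

use candle_core::{DType, Device, Result, Tensor};

// Hypothetical embedding wrapper; illustrative only, not the struct reworked
// by this commit.
struct Embedding {
    weight: Tensor, // (vocab_size, hidden_size); may be a non-contiguous view
}

impl Embedding {
    fn forward(&self, ids: &Tensor) -> Result<Tensor> {
        // Assumption: the row-gather below wants contiguous storage, so force
        // the (possibly transposed or sliced) weight into a contiguous layout.
        let weight = self.weight.contiguous()?;
        // Select one hidden_size row per token id.
        weight.index_select(ids, 0)
    }
}

fn main() -> Result<()> {
    // A small weight matrix turned into a non-contiguous view via transpose.
    let weight = Tensor::zeros((4, 16), DType::F32, &Device::Cpu)?.t()?;
    let emb = Embedding { weight };
    let ids = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
    let xs = emb.forward(&ids)?;
    println!("{:?}", xs.shape()); // (3, 4)
    Ok(())
}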
@@ -350,7 +350,8 @@ impl Llama {
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (_, t) = x.shape().r2()?;
+        // TODO: Support for mini-batches? (i.e. r2)
+        let t = x.shape().r1()?;
         let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, freqs_cis)?;
@@ -427,7 +428,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
+        let input = Tensor::new(ctxt, &Device::Cpu)?;
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
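Read together, the two hunks change the example's calling convention: forward now checks for a rank-1 tensor of token ids (r1()), the (1, len) reshape at the call site is dropped, and the TODO defers mini-batch support. A hedged sketch of the resulting call site, reusing identifiers from the diff (tokens, llama, freqs_cis, CONTEXT_SIZE are assumed to be in scope as in the example):

// Call site after the change; identifiers are taken from the diff above and
// assumed to be defined elsewhere in the example.
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
// No reshape to (1, ctxt.len()): forward() now expects rank-1 token ids.
let input = Tensor::new(ctxt, &Device::Cpu)?;
let logits = llama.forward(&input, &freqs_cis)?;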