Llama batch (#144)

* Add a batch dimension to llama.

* Bugfixes.
Author: Laurent Mazare (committed by GitHub)
Date: 2023-07-12 11:38:19 +01:00
parent bcf96e3cf3
commit b3b39cca92
3 changed files with 32 additions and 52 deletions


@@ -12,8 +12,6 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
-// TODO: This does not use a batch dimension. If adding it back, be cautious about the
-// transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};
@@ -200,13 +198,14 @@ fn main() -> Result<()> {
             tokens.len()
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-        let input = Tensor::new(ctxt, &device)?;
+        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let freqs_cis = if cache.use_kv_cache {
             freqs_cis.narrow(1, index_pos, ctxt.len())?
         } else {
             freqs_cis.clone()
         };
         let logits = llama.forward(&input, &freqs_cis)?;
+        let logits = logits.squeeze(0)?;
         index_pos += ctxt.len();
         let next_token = if let Some(temperature) = args.temperature {
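
For reference, the change threads a leading batch dimension through the example: the 1D token tensor of shape (seq_len) becomes (1, seq_len) via unsqueeze(0) before the forward pass, and the logits drop that dimension again with squeeze(0). Below is a minimal sketch of this shape round-trip, assuming candle's Tensor API is imported as the candle crate; the token values are illustrative, not from the commit.

// Sketch only (not part of the commit): add and remove a batch dimension.
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical token ids standing in for `ctxt` above.
    let ctxt: &[u32] = &[1, 15043, 29892];
    let input = Tensor::new(ctxt, &device)?; // shape (3)
    let batched = input.unsqueeze(0)?;       // shape (1, 3): batch dim added
    assert_eq!(batched.dims(), &[1, 3]);
    let logits = batched.squeeze(0)?;        // shape (3): batch dim removed
    assert_eq!(logits.dims(), &[3]);
    Ok(())
}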