mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
@ -12,8 +12,6 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
// TODO: This does not use a batch dimension. If adding it back, be cautious about the
|
||||
// transposition operations.
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
@ -200,13 +198,14 @@ fn main() -> Result<()> {
|
||||
tokens.len()
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?;
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let freqs_cis = if cache.use_kv_cache {
|
||||
freqs_cis.narrow(1, index_pos, ctxt.len())?
|
||||
} else {
|
||||
freqs_cis.clone()
|
||||
};
|
||||
let logits = llama.forward(&input, &freqs_cis)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = if let Some(temperature) = args.temperature {
|
||||
|
Reference in New Issue
Block a user