Take as input slices of tensors as well as slices of &Tensors.

This commit is contained in:
laurent
2023-06-25 17:07:09 +01:00
parent 8b67f294e8
commit 334524e2c4
2 changed files with 26 additions and 17 deletions

View File

@ -427,7 +427,7 @@ fn main() -> Result<()> {
let mut rng = thread_rng();
for index in 0..args.sample_len {
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &Device::Cpu)?;
let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
let logits = llama.forward(&input, &freqs_cis)?;
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
let logits_v: Vec<f32> = prs.to_vec1()?;