mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Take as input slices of tensors as well as slices of &Tensors.
This commit is contained in:
@@ -427,7 +427,7 @@ fn main() -> Result<()> {
|
||||
let mut rng = thread_rng();
|
||||
for index in 0..args.sample_len {
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
||||
- let input = Tensor::new(ctxt, &Device::Cpu)?;
|
||||
+ let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
|
||||
let logits = llama.forward(&input, &freqs_cis)?;
|
||||
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
|
Reference in New Issue
Block a user