mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Simplify the pattern matching logic in the cuda backend.
This commit is contained in:
@@ -487,6 +487,7 @@ fn main() -> Result<()> {
|
||||
let mut rng = thread_rng();
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..args.sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
||||
let input = Tensor::new(ctxt, &device)?;
|
||||
let logits = llama.forward(&input, &freqs_cis)?;
|
||||
@@ -496,6 +497,7 @@ fn main() -> Result<()> {
|
||||
let next_token = distr.sample(&mut rng) as u32;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
println!("> {:?}", start_gen.elapsed());
|
||||
println!(
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
|
Reference in New Issue
Block a user