Simplify the pattern matching logic in the cuda backend.

This commit is contained in:
laurent
2023-06-29 09:21:11 +01:00
parent eda46d2df2
commit 122e334d0c
5 changed files with 78 additions and 89 deletions

View File

@ -487,6 +487,7 @@ fn main() -> Result<()> {
let mut rng = thread_rng();
let start_gen = std::time::Instant::now();
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &device)?;
let logits = llama.forward(&input, &freqs_cis)?;
@ -496,6 +497,7 @@ fn main() -> Result<()> {
let next_token = distr.sample(&mut rng) as u32;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,