From a09d451d11a91ea7a7feaa40460abb282581a0f1 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 1 May 2024 22:25:47 +0200
Subject: [PATCH] Support top-k in the llama example. (#2150)

---
 candle-examples/examples/llama/main.rs | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 72656295..fa7ce81b 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
 
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 
@@ -54,12 +54,16 @@ struct Args {
     #[arg(long)]
     top_p: Option<f64>,
 
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 10000)]
+    #[arg(short = 'n', long, default_value_t = 10000)]
     sample_len: usize,
 
     /// Disable the key-value cache.
@@ -166,7 +170,21 @@ fn main() -> Result<()> {
 
     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
+    let mut logits_processor = {
+        let temperature = args.temperature;
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };
+
     let mut start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
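
Note: the patch maps each combination of the top-k/top-p flags onto a variant of
candle_transformers' Sampling enum, with a non-positive temperature falling back
to greedy (arg-max) decoding. The standalone sketch below mirrors that dispatch
outside the example; it assumes the `candle` alias for the candle-core crate
that candle-examples uses, and the flag values and toy logits are made up purely
for illustration.

use candle::{Device, Result, Tensor};
use candle_transformers::generation::{LogitsProcessor, Sampling};

fn main() -> Result<()> {
    // Hypothetical stand-ins for args.seed, args.top_k, args.top_p, args.temperature.
    let seed = 299792458;
    let (top_k, top_p, temperature) = (Some(40usize), Some(0.9f64), 0.8f64);
    // Same dispatch as the patch: a non-positive temperature means greedy decoding,
    // otherwise the flag combination picks the sampling strategy.
    let sampling = if temperature <= 0. {
        Sampling::ArgMax
    } else {
        match (top_k, top_p) {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    };
    let mut processor = LogitsProcessor::from_sampling(seed, sampling);
    // Toy logits over a four-token vocabulary, in place of real model output.
    let logits = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &Device::Cpu)?;
    let token = processor.sample(&logits)?;
    println!("sampled token id: {token}");
    Ok(())
}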
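
With the patch applied, both knobs can be combined on the command line; an
illustrative invocation (model and prompt flags elided), relying on clap's
default kebab-case flag names and the new -n short flag:

cargo run --example llama --release -- --temperature 0.8 --top-k 40 --top-p 0.9 -n 100

When both flags are given, Sampling::TopKThenTopP restricts sampling to the k
highest logits first and then applies the nucleus cutoff p among those.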