Support top-k in tthe llama example. (#2150)

This commit is contained in:
Laurent Mazare
2024-05-01 22:25:47 +02:00
committed by GitHub
parent fa06f5f5f9
commit a09d451d11

View File

@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
@ -54,12 +54,16 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)]
#[arg(short = 'n', long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
@ -166,7 +170,21 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;