mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Support top-k in tthe llama example. (#2150)
This commit is contained in:
@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
|
||||
@ -54,12 +54,16 @@ struct Args {
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, default_value_t = 10000)]
|
||||
#[arg(short = 'n', long, default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// Disable the key-value cache.
|
||||
@ -166,7 +170,21 @@ fn main() -> Result<()> {
|
||||
|
||||
println!("starting the inference loop");
|
||||
print!("{prompt}");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let mut start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
let mut token_generated = 0;
|
||||
|
Reference in New Issue
Block a user