mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Support top-k in tthe llama example. (#2150)
This commit is contained in:
@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
|
|||||||
|
|
||||||
use candle::{DType, Tensor};
|
use candle::{DType, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
@ -54,12 +54,16 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
top_p: Option<f64>,
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
/// The seed to use when generating random samples.
|
||||||
#[arg(long, default_value_t = 299792458)]
|
#[arg(long, default_value_t = 299792458)]
|
||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, default_value_t = 10000)]
|
#[arg(short = 'n', long, default_value_t = 10000)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
/// Disable the key-value cache.
|
/// Disable the key-value cache.
|
||||||
@ -166,7 +170,21 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
print!("{prompt}");
|
print!("{prompt}");
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
|
let mut logits_processor = {
|
||||||
|
let temperature = args.temperature;
|
||||||
|
let sampling = if temperature <= 0. {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match (args.top_k, args.top_p) {
|
||||||
|
(None, None) => Sampling::All { temperature },
|
||||||
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
let mut start_gen = std::time::Instant::now();
|
let mut start_gen = std::time::Instant::now();
|
||||||
let mut index_pos = 0;
|
let mut index_pos = 0;
|
||||||
let mut token_generated = 0;
|
let mut token_generated = 0;
|
||||||
|
Reference in New Issue
Block a user