mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Include topk sampling in the quantized example. (#2005)
* Include topk sampling in the quantized example. * Also sample with top-k on the mistral side.
This commit is contained in:
@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
@ -39,11 +39,26 @@ impl TextGeneration {
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
top_k: Option<usize>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
let logits_processor = {
|
||||
let temperature = temp.unwrap_or(0.);
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (top_k, top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(seed, sampling)
|
||||
};
|
||||
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
@ -159,6 +174,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
@ -314,6 +333,7 @@ fn main() -> Result<()> {
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.top_k,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
|
Reference in New Issue
Block a user