diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index c00af3fe..6aa3f51e 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
@@ -39,11 +39,26 @@ impl TextGeneration {
         seed: u64,
         temp: Option<f64>,
         top_p: Option<f64>,
+        top_k: Option<usize>,
         repeat_penalty: f32,
         repeat_last_n: usize,
         device: &Device,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        let logits_processor = {
+            let temperature = temp.unwrap_or(0.);
+            let sampling = if temperature <= 0. {
+                Sampling::ArgMax
+            } else {
+                match (top_k, top_p) {
+                    (None, None) => Sampling::All { temperature },
+                    (Some(k), None) => Sampling::TopK { k, temperature },
+                    (None, Some(p)) => Sampling::TopP { p, temperature },
+                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+                }
+            };
+            LogitsProcessor::from_sampling(seed, sampling)
+        };
+
         Self {
             model,
             tokenizer: TokenOutputStream::new(tokenizer),
@@ -159,6 +174,10 @@ struct Args {
     #[arg(long)]
     top_p: Option<f64>,
 
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -314,6 +333,7 @@ fn main() -> Result<()> {
         args.seed,
         args.temperature,
         args.top_p,
+        args.top_k,
         args.repeat_penalty,
         args.repeat_last_n,
         &device,
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index b03768ed..ea7f70eb 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -10,7 +10,7 @@ use tokenizers::Tokenizer;
 
 use candle::quantized::{ggml_file, gguf_file};
 use candle::Tensor;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
 
 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::quantized_llama as model;
@@ -200,6 +200,10 @@ struct Args {
     #[arg(long)]
     top_p: Option<f64>,
 
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -349,11 +353,6 @@ fn main() -> anyhow::Result<()> {
     #[cfg(feature = "cuda")]
     candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
 
-    let temperature = if args.temperature == 0. {
-        None
-    } else {
-        Some(args.temperature)
-    };
     let _guard = if args.tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -500,7 +499,20 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
+    let mut logits_processor = {
+        let temperature = args.temperature;
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };
 
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = if !args.split_prompt {
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index 257d9171..c250a186 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -7,6 +7,7 @@ pub enum Sampling {
     All { temperature: f64 },
     TopK { k: usize, temperature: f64 },
     TopP { p: f64, temperature: f64 },
+    TopKThenTopP { k: usize, p: f64, temperature: f64 },
 }
 
 pub struct LogitsProcessor {
@@ -77,7 +78,6 @@ impl LogitsProcessor {
             self.sample_multinomial(prs)
         } else {
             let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
-            // Sort by descending probability.
             let (indices, _, _) =
                 argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
             let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
@@ -86,6 +86,26 @@
         }
     }
 
+    // top-k sampling samples from the k tokens with the largest probabilities.
+    // then top-p sampling.
+    fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
+        if top_k >= prs.len() {
+            self.sample_topp(prs, top_p)
+        } else {
+            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+            let (indices, _, _) =
+                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
+            let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
+            let sum_p = prs.iter().sum::<f32>();
+            let index = if top_p <= 0.0 || top_p >= sum_p {
+                self.sample_multinomial(&prs)?
+            } else {
+                self.sample_topp(&mut prs, top_p)?
+            };
+            Ok(indices[index as usize] as u32)
+        }
+    }
+
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         self.sample_f(logits, |_| {})
     }
@@ -120,6 +140,10 @@
                 let mut prs = prs(*temperature)?;
                 self.sample_topk(&mut prs, *k)?
             }
+            Sampling::TopKThenTopP { k, p, temperature } => {
+                let mut prs = prs(*temperature)?;
+                self.sample_topk_topp(&mut prs, *k, *p as f32)?
+            }
         };
         Ok(next_token)
     }