From 2653002f292a0b1b86d15eadb42a35fb40ee7876 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 14 Apr 2025 15:42:42 +0200 Subject: [PATCH] Gumbel-Softmax sampling. (#2894) * Gumbel-Softmax sampling. * Add a sampling test. * Share the gumbel-softmax bits. --- candle-examples/examples/helium/main.rs | 2 +- candle-nn/src/lib.rs | 1 + candle-nn/src/sampling.rs | 20 +++++++++++++++++ candle-transformers/src/generation/mod.rs | 10 +++++++++ candle-transformers/tests/generation_tests.rs | 22 +++++++++++++++++++ 5 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 candle-nn/src/sampling.rs diff --git a/candle-examples/examples/helium/main.rs b/candle-examples/examples/helium/main.rs index fc7e6b60..7be5f163 100644 --- a/candle-examples/examples/helium/main.rs +++ b/candle-examples/examples/helium/main.rs @@ -46,7 +46,7 @@ impl TextGeneration { Sampling::ArgMax } else { match (top_k, top_p) { - (None, None) => Sampling::All { temperature }, + (None, None) => Sampling::GumbelSoftmax { temperature }, (Some(k), None) => Sampling::TopK { k, temperature }, (None, Some(p)) => Sampling::TopP { p, temperature }, (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 2113566d..d21f12f5 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -31,6 +31,7 @@ pub mod ops; pub mod optim; pub mod rnn; pub mod rotary_emb; +pub mod sampling; pub mod sequential; pub mod var_builder; pub mod var_map; diff --git a/candle-nn/src/sampling.rs b/candle-nn/src/sampling.rs new file mode 100644 index 00000000..ff2785c0 --- /dev/null +++ b/candle-nn/src/sampling.rs @@ -0,0 +1,20 @@ +use candle::{Result, Tensor}; + +/// Sample according to the Gumbel-Softmax distribution. +pub fn gumbel_softmax( + logits: &Tensor, + temperature: f64, + dim: D, +) -> Result { + if temperature <= 0.0 { + logits.argmax(dim) + } else if temperature == 1.0 { + let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?; + let sampled = (logits - minus_g)?.argmax(dim)?; + Ok(sampled) + } else { + let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?; + let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?; + Ok(sampled) + } +} diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index b4d37a6c..d3aee686 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -13,6 +13,8 @@ pub enum Sampling { TopK { k: usize, temperature: f64 }, TopP { p: f64, temperature: f64 }, TopKThenTopP { k: usize, p: f64, temperature: f64 }, + // Note that the rng is not used for the Gumbel-Softmax sampling. + GumbelSoftmax { temperature: f64 }, } pub struct LogitsProcessor { @@ -49,6 +51,11 @@ impl LogitsProcessor { Ok(next_token) } + fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result { + let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?; + sampled.to_vec0::() + } + fn sample_multinomial(&mut self, prs: &Vec) -> Result { let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?; let next_token = distr.sample(&mut self.rng) as u32; @@ -127,6 +134,9 @@ impl LogitsProcessor { let next_token = match &self.sampling { Sampling::ArgMax => self.sample_argmax(logits)?, + Sampling::GumbelSoftmax { temperature } => { + self.sample_gumbel_softmax(&logits, *temperature)? + } Sampling::All { temperature } => { let prs = prs(*temperature)?; self.sample_multinomial(&prs)? diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs index cc499a44..ee7df169 100644 --- a/candle-transformers/tests/generation_tests.rs +++ b/candle-transformers/tests/generation_tests.rs @@ -54,3 +54,25 @@ fn sample_with_top_k() -> Result<()> { assert_eq!(token, 2); Ok(()) } + +#[test] +fn sample_gumbel() -> Result<()> { + let mut logits_process = LogitsProcessor::from_sampling( + 42, + candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 }, + ); + let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?; + let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::()?; + let mut counts = vec![0f64; 4]; + let samples = 100000; + for _ in 0..samples { + let token = logits_process.sample(&logits)?; + counts[token as usize] += 1f64 / samples as f64; + } + for i in 0..4 { + if (counts[i] - sm[i]).abs() > 0.05 { + panic!("pr mismatch {counts:?} {sm:?}"); + } + } + Ok(()) +}