Gumbel-Softmax sampling. (#2894)

* Gumbel-Softmax sampling.

* Add a sampling test.

* Share the gumbel-softmax bits.
This commit is contained in:
Laurent Mazare
2025-04-14 15:42:42 +02:00
committed by GitHub
parent a52b76ae82
commit 2653002f29
5 changed files with 54 additions and 1 deletions

View File

@ -31,6 +31,7 @@ pub mod ops;
pub mod optim;
pub mod rnn;
pub mod rotary_emb;
pub mod sampling;
pub mod sequential;
pub mod var_builder;
pub mod var_map;

20
candle-nn/src/sampling.rs Normal file
View File

@ -0,0 +1,20 @@
use candle::{Result, Tensor};
/// Sample according to the Gumbel-Softmax distribution.
pub fn gumbel_softmax<D: candle::shape::Dim>(
logits: &Tensor,
temperature: f64,
dim: D,
) -> Result<Tensor> {
if temperature <= 0.0 {
logits.argmax(dim)
} else if temperature == 1.0 {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
}
}