mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Gumbel-Softmax sampling. (#2894)
* Gumbel-Softmax sampling. * Add a sampling test. * Share the gumbel-softmax bits.
This commit is contained in:
@ -31,6 +31,7 @@ pub mod ops;
|
||||
pub mod optim;
|
||||
pub mod rnn;
|
||||
pub mod rotary_emb;
|
||||
pub mod sampling;
|
||||
pub mod sequential;
|
||||
pub mod var_builder;
|
||||
pub mod var_map;
|
||||
|
20
candle-nn/src/sampling.rs
Normal file
20
candle-nn/src/sampling.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
/// Sample according to the Gumbel-Softmax distribution.
|
||||
pub fn gumbel_softmax<D: candle::shape::Dim>(
|
||||
logits: &Tensor,
|
||||
temperature: f64,
|
||||
dim: D,
|
||||
) -> Result<Tensor> {
|
||||
if temperature <= 0.0 {
|
||||
logits.argmax(dim)
|
||||
} else if temperature == 1.0 {
|
||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||
let sampled = (logits - minus_g)?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
} else {
|
||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
||||
Ok(sampled)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user