Gumbel-Softmax sampling. (#2894)
* Gumbel-Softmax sampling.
* Add a sampling test.
* Share the gumbel-softmax bits.
@@ -54,3 +54,25 @@ fn sample_with_top_k() -> Result<()> {
     assert_eq!(token, 2);
     Ok(())
 }
+
+#[test]
+fn sample_gumbel() -> Result<()> {
+    let mut logits_process = LogitsProcessor::from_sampling(
+        42,
+        candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 },
+    );
+    let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?;
+    let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::<f64>()?;
+    let mut counts = vec![0f64; 4];
+    let samples = 100000;
+    for _ in 0..samples {
+        let token = logits_process.sample(&logits)?;
+        counts[token as usize] += 1f64 / samples as f64;
+    }
+    for i in 0..4 {
+        if (counts[i] - sm[i]).abs() > 0.05 {
+            panic!("pr mismatch {counts:?} {sm:?}");
+        }
+    }
+    Ok(())
+}
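For context on why this test is a meaningful check: Gumbel-max sampling adds independent Gumbel(0, 1) noise to the temperature-scaled logits and takes the argmax, which draws each token with exactly its softmax probability. The test therefore compares empirical token frequencies over 100000 draws against candle_nn::ops::softmax, with a 0.05 tolerance per token. Below is a minimal standalone sketch of that trick in plain Rust; it assumes the rand crate for uniform noise, and sample_gumbel_max is a hypothetical helper for illustration, not candle's implementation.

use rand::Rng;

/// Sketch of Gumbel-max sampling over raw logits (hypothetical helper):
/// argmax_i(logits[i] / temperature + g_i) with g_i ~ Gumbel(0, 1)
/// selects index i with probability softmax(logits / temperature)[i].
fn sample_gumbel_max(logits: &[f64], temperature: f64) -> usize {
    let mut rng = rand::thread_rng();
    let mut best_idx = 0;
    let mut best_score = f64::NEG_INFINITY;
    for (i, &logit) in logits.iter().enumerate() {
        // Gumbel(0, 1) noise via inverse transform: g = -ln(-ln(u)), u ~ Uniform(0, 1).
        let u: f64 = rng.gen_range(f64::MIN_POSITIVE..1.0);
        let g = -(-u.ln()).ln();
        let score = logit / temperature + g;
        if score > best_score {
            best_score = score;
            best_idx = i;
        }
    }
    best_idx
}

Swapping this sketch into the loop above in place of LogitsProcessor::sample should yield the same statistical behaviour: after many draws, the per-token frequencies converge to the softmax of the logits, which is exactly what the tolerance check asserts.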