Fix the gumbel softmax by casting to f32. (#2928)

This commit is contained in:
Laurent Mazare
2025-04-28 19:48:51 +02:00
committed by GitHub
parent e98754fc5a
commit d4bac37a61

View File

@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
) -> Result<Tensor> {
if temperature <= 0.0 {
logits.argmax(dim)
} else if temperature == 1.0 {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
// Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable.
let logits = logits.to_dtype(candle::DType::F32)?;
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
if temperature == 1.0 {
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
}
}
}