From d4bac37a61df27742023d5a5b8b31aca697c9307 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 28 Apr 2025 19:48:51 +0200 Subject: [PATCH] Fix the gumbel softmax by casting to f32. (#2928) --- candle-nn/src/sampling.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/candle-nn/src/sampling.rs b/candle-nn/src/sampling.rs index ff2785c0..80227413 100644 --- a/candle-nn/src/sampling.rs +++ b/candle-nn/src/sampling.rs @@ -8,13 +8,16 @@ pub fn gumbel_softmax( ) -> Result { if temperature <= 0.0 { logits.argmax(dim) - } else if temperature == 1.0 { - let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?; - let sampled = (logits - minus_g)?.argmax(dim)?; - Ok(sampled) } else { + // Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable. + let logits = logits.to_dtype(candle::DType::F32)?; let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?; - let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?; - Ok(sampled) + if temperature == 1.0 { + let sampled = (logits - minus_g)?.argmax(dim)?; + Ok(sampled) + } else { + let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?; + Ok(sampled) + } } }