mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix the gumbel softmax by casting to f32. (#2928)
This commit is contained in:
@ -8,13 +8,16 @@ pub fn gumbel_softmax<D: candle::shape::Dim>(
|
|||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
if temperature <= 0.0 {
|
if temperature <= 0.0 {
|
||||||
logits.argmax(dim)
|
logits.argmax(dim)
|
||||||
} else if temperature == 1.0 {
|
} else {
|
||||||
|
// Cast to f32, doing the Gumbel softmax in bf16 is a bit unstable.
|
||||||
|
let logits = logits.to_dtype(candle::DType::F32)?;
|
||||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||||
|
if temperature == 1.0 {
|
||||||
let sampled = (logits - minus_g)?.argmax(dim)?;
|
let sampled = (logits - minus_g)?.argmax(dim)?;
|
||||||
Ok(sampled)
|
Ok(sampled)
|
||||||
} else {
|
} else {
|
||||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
|
||||||
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
||||||
Ok(sampled)
|
Ok(sampled)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user