Gumbel-Softmax sampling. (#2894)

* Gumbel-Softmax sampling.

* Add a sampling test.

* Share the gumbel-softmax bits.
This commit is contained in:
Laurent Mazare
2025-04-14 15:42:42 +02:00
committed by GitHub
parent a52b76ae82
commit 2653002f29
5 changed files with 54 additions and 1 deletions

View File

@ -46,7 +46,7 @@ impl TextGeneration {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::All { temperature },
(None, None) => Sampling::GumbelSoftmax { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },