From 912561614f0fb0fc1e9ccff49448d3d2a85302ce Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 21 Aug 2023 07:51:46 +0100 Subject: [PATCH] Better handling of zero temperatures. (#532) --- candle-transformers/src/generation/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index d2ac33e9..b1d20168 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -16,7 +16,8 @@ impl LogitsProcessor { pub fn sample(&mut self, logits: &Tensor) -> Result { let logits = logits.to_dtype(DType::F32)?; - let next_token = if let Some(temperature) = self.temperature { + let temperature = self.temperature.unwrap_or(0.); + let next_token = if temperature > 0. { let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?; let prs: Vec = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;