Better handling of zero temperatures. (#532)

This commit is contained in:
Laurent Mazare
2023-08-21 07:51:46 +01:00
committed by GitHub
parent 8c232d706b
commit 912561614f

View File

@ -16,7 +16,8 @@ impl LogitsProcessor {
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
let temperature = self.temperature.unwrap_or(0.);
let next_token = if temperature > 0. {
let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;