From 5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 23 Mar 2024 15:47:39 +0100 Subject: [PATCH] Allow for arbitrary temperature modifications. --- candle-transformers/src/generation/mod.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index 530a6b48..257d9171 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -87,11 +87,17 @@ impl LogitsProcessor { } pub fn sample(&mut self, logits: &Tensor) -> Result { + self.sample_f(logits, |_| {}) + } + + pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result { let logits = logits.to_dtype(DType::F32)?; let prs = |temperature: f64| -> Result> { let logits = (&logits / temperature)?; let prs = candle_nn::ops::softmax_last_dim(&logits)?; - prs.to_vec1() + let mut prs = prs.to_vec1()?; + f(&mut prs); + Ok(prs) }; let next_token = match &self.sampling {