diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index 530a6b48..257d9171 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -87,11 +87,17 @@ impl LogitsProcessor { } pub fn sample(&mut self, logits: &Tensor) -> Result { + self.sample_f(logits, |_| {}) + } + + pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result { let logits = logits.to_dtype(DType::F32)?; let prs = |temperature: f64| -> Result> { let logits = (&logits / temperature)?; let prs = candle_nn::ops::softmax_last_dim(&logits)?; - prs.to_vec1() + let mut prs = prs.to_vec1()?; + f(&mut prs); + Ok(prs) }; let next_token = match &self.sampling {