Allow for arbitrary temperature modifications.

This commit is contained in:
laurent
2024-03-23 15:47:39 +01:00
parent a62a97340c
commit 5e70821dd0

View File

@ -87,11 +87,17 @@ impl LogitsProcessor {
}
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
self.sample_f(logits, |_| {})
}
pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let prs = |temperature: f64| -> Result<Vec<f32>> {
let logits = (&logits / temperature)?;
let prs = candle_nn::ops::softmax_last_dim(&logits)?;
prs.to_vec1()
let mut prs = prs.to_vec1()?;
f(&mut prs);
Ok(prs)
};
let next_token = match &self.sampling {