mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Allow for arbitrary temperature modifications.
This commit is contained in:
@ -87,11 +87,17 @@ impl LogitsProcessor {
|
||||
}
|
||||
|
||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||
self.sample_f(logits, |_| {})
|
||||
}
|
||||
|
||||
pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let prs = |temperature: f64| -> Result<Vec<f32>> {
|
||||
let logits = (&logits / temperature)?;
|
||||
let prs = candle_nn::ops::softmax_last_dim(&logits)?;
|
||||
prs.to_vec1()
|
||||
let mut prs = prs.to_vec1()?;
|
||||
f(&mut prs);
|
||||
Ok(prs)
|
||||
};
|
||||
|
||||
let next_token = match &self.sampling {
|
||||
|
Reference in New Issue
Block a user