mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Allow for arbitrary temperature modifications.
This commit is contained in:
@@ -87,11 +87,17 @@ impl LogitsProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||||
|
self.sample_f(logits, |_| {})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
|
||||||
let logits = logits.to_dtype(DType::F32)?;
|
let logits = logits.to_dtype(DType::F32)?;
|
||||||
let prs = |temperature: f64| -> Result<Vec<f32>> {
|
let prs = |temperature: f64| -> Result<Vec<f32>> {
|
||||||
let logits = (&logits / temperature)?;
|
let logits = (&logits / temperature)?;
|
||||||
let prs = candle_nn::ops::softmax_last_dim(&logits)?;
|
let prs = candle_nn::ops::softmax_last_dim(&logits)?;
|
||||||
prs.to_vec1()
|
let mut prs = prs.to_vec1()?;
|
||||||
|
f(&mut prs);
|
||||||
|
Ok(prs)
|
||||||
};
|
};
|
||||||
|
|
||||||
let next_token = match &self.sampling {
|
let next_token = match &self.sampling {
|
||||||
|
Reference in New Issue
Block a user