mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
@ -4,7 +4,7 @@
|
||||
//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
|
||||
//! and combinations thereof.
|
||||
use candle::{Context, DType, Error, Result, Tensor};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use rand::{distr::Distribution, SeedableRng};
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum Sampling {
|
||||
@ -50,7 +50,7 @@ impl LogitsProcessor {
|
||||
}
|
||||
|
||||
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||
let next_token = distr.sample(&mut self.rng) as u32;
|
||||
Ok(next_token)
|
||||
}
|
||||
|
Reference in New Issue
Block a user