mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Fix for whisper example. rand::distribution is now rand::distr (#2811)
This commit is contained in:

committed by
GitHub

parent
3afb04925a
commit
0b24f7f0a4
@ -14,7 +14,9 @@ use candle::{Device, IndexOp, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use rand::distr::weighted::WeightedIndex;
|
||||
use rand::distr::Distribution;
|
||||
use rand::SeedableRng;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod multilingual;
|
||||
@ -208,7 +210,7 @@ impl Decoder {
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
let distr = WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(&mut self.rng) as u32
|
||||
} else {
|
||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||
|
Reference in New Issue
Block a user