mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
fixed rand imports for whisper-microphone example (#2834)
This commit is contained in:
@ -9,7 +9,7 @@ use candle::{Device, IndexOp, Tensor};
|
|||||||
use candle_nn::{ops::softmax, VarBuilder};
|
use candle_nn::{ops::softmax, VarBuilder};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
use rand::{distr::Distribution, SeedableRng};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
mod multilingual;
|
mod multilingual;
|
||||||
@ -204,7 +204,7 @@ impl Decoder {
|
|||||||
let next_token = if t > 0f64 {
|
let next_token = if t > 0f64 {
|
||||||
let prs = softmax(&(&logits / t)?, 0)?;
|
let prs = softmax(&(&logits / t)?, 0)?;
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
|
||||||
distr.sample(&mut self.rng) as u32
|
distr.sample(&mut self.rng) as u32
|
||||||
} else {
|
} else {
|
||||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||||
|
Reference in New Issue
Block a user