mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
@ -3,7 +3,7 @@ use anyhow::Error as E;
|
||||
use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
pub use candle_transformers::models::whisper::{self as m, Config};
|
||||
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
|
||||
use rand::{distr::Distribution, rngs::StdRng, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokenizers::Tokenizer;
|
||||
use wasm_bindgen::prelude::*;
|
||||
@ -221,7 +221,7 @@ impl Decoder {
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(&mut self.rng) as u32
|
||||
} else {
|
||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||
|
Reference in New Issue
Block a user