mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Populate the no-speech probability.
This commit is contained in:
@ -35,6 +35,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||
// Tokenizer dependent bits.
|
||||
const SOT_TOKEN: u32 = 50257;
|
||||
const EOT_TOKEN: u32 = 50256;
|
||||
const NO_SPEECH_TOKEN: u32 = 50361;
|
||||
const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -87,9 +89,9 @@ impl Decode {
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
let sample_len = model.config.n_text_ctx / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let no_speech_prob = f64::NAN;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![SOT_TOKEN];
|
||||
for _i in 0..sample_len {
|
||||
for i in 0..sample_len {
|
||||
let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
|
||||
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
@ -97,10 +99,21 @@ impl Decode {
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = logits
|
||||
.get(0)?
|
||||
.softmax(0)?
|
||||
.get(NO_SPEECH_TOKEN as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
|
||||
let (seq_len, _) = logits.shape().r2()?;
|
||||
let logits = logits.get(seq_len - 1)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = (&logits / t)?.softmax(logits.rank() - 1)?;
|
||||
let prs = (&logits / t)?.softmax(0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(&mut self.rng) as u32
|
||||
@ -146,11 +159,16 @@ impl Decode {
|
||||
return dr;
|
||||
}
|
||||
// On errors, we try again with a different temperature.
|
||||
if let Ok(dr) = dr {
|
||||
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|
||||
|| dr.avg_logprob < LOGPROB_THRESHOLD;
|
||||
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
|
||||
return Ok(dr);
|
||||
match dr {
|
||||
Ok(dr) => {
|
||||
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|
||||
|| dr.avg_logprob < LOGPROB_THRESHOLD;
|
||||
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
|
||||
return Ok(dr);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
println!("Error running at {t}: {err}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user