Populate the no-speech probability.

This commit is contained in:
laurent
2023-07-05 08:54:04 +01:00
parent 9694e35db0
commit a824c5c3e3

View File

@ -35,6 +35,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits.
const SOT_TOKEN: u32 = 50257;
const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
const NO_TIMESTAMP_TOKEN: u32 = 50362;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@ -87,9 +89,9 @@ impl Decode {
println!("audio features: {:?}", audio_features.dims());
let sample_len = model.config.n_text_ctx / 2;
let mut sum_logprob = 0f64;
let no_speech_prob = f64::NAN;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![SOT_TOKEN];
for _i in 0..sample_len {
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
// The model expects a batch dim but this inference loop does not handle
@ -97,10 +99,21 @@ impl Decode {
let tokens_t = tokens_t.unsqueeze(0)?;
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
let logits = logits.squeeze(0)?;
// Extract the no-speech probability on the first iteration by taking the logits at the
// first position and reading the softmax probability of the no-speech token.
if i == 0 {
no_speech_prob = logits
.get(0)?
.softmax(0)?
.get(NO_SPEECH_TOKEN as usize)?
.to_scalar::<f32>()? as f64;
}
let (seq_len, _) = logits.shape().r2()?;
let logits = logits.get(seq_len - 1)?;
let next_token = if t > 0f64 {
let prs = (&logits / t)?.softmax(logits.rank() - 1)?;
let prs = (&logits / t)?.softmax(0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
@ -146,11 +159,16 @@ impl Decode {
return dr;
}
// On errors, we try again with a different temperature.
if let Ok(dr) = dr {
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
return Ok(dr);
match dr {
Ok(dr) => {
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
Err(err) => {
println!("Error running at {t}: {err}")
}
}
}