Clean the decode loop of the whisper example.

laurent
2023-07-05 08:37:26 +01:00
parent fbdabf0325
commit 9694e35db0
2 changed files with 24 additions and 14 deletions

@@ -2,7 +2,8 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
 // - kv-cache support?
-// - language detection?
+// - Language detection?
+// - Batch size greater than 1.
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
@@ -31,6 +32,10 @@ const LOGPROB_THRESHOLD: f64 = -1.0;
 const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
 const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 
+// Tokenizer dependent bits.
+const SOT_TOKEN: u32 = 50257;
+const EOT_TOKEN: u32 = 50256;
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
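
The ids hard-coded in this hunk are tokenizer dependent, as the new comment notes (these values match the English-only Whisper vocabulary). A more principled variant, in the spirit of the TODO removed below, could resolve them from the tokenizer at startup. A minimal sketch, assuming the tokenizers crate and the standard Whisper special token names (not part of this commit):

    use anyhow::{anyhow, Result};
    use tokenizers::Tokenizer;

    // Resolve the special token ids from the vocabulary instead of
    // hard-coding them; the token names are the usual Whisper ones.
    fn special_tokens(tokenizer: &Tokenizer) -> Result<(u32, u32)> {
        let sot = tokenizer
            .token_to_id("<|startoftranscript|>")
            .ok_or_else(|| anyhow!("no start-of-transcript token"))?;
        let eot = tokenizer
            .token_to_id("<|endoftext|>")
            .ok_or_else(|| anyhow!("no end-of-text token"))?;
        Ok((sot, eot))
    }
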
@@ -83,11 +88,12 @@ impl Decode {
         let sample_len = model.config.n_text_ctx / 2;
         let mut sum_logprob = 0f64;
         let no_speech_prob = f64::NAN;
-        // TODO: 50257 is the start of transcript token, be more principled about getting the initial tokens
-        let mut tokens: Vec<u32> = vec![50257];
+        let mut tokens = vec![SOT_TOKEN];
         for _i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
-            // Insert a batch dim.
+            // The model expects a batch dim but this inference loop does not handle
+            // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
             let logits = model.decoder.forward(&tokens_t, &audio_features)?;
             let logits = logits.squeeze(0)?;
@@ -112,11 +118,10 @@ impl Decode {
                 .softmax(logits.rank() - 1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            // 50256 is the eot token, TODO: parameterize this.
-            if next_token == 50256 || tokens.len() > model.config.n_text_ctx {
+            sum_logprob += prob.ln();
+            if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
                 break;
             }
-            sum_logprob += prob.ln();
         }
         let text = self
             .tokenizer
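
A behavioral note on this hunk: sum_logprob is now accumulated before the end-of-text check, so the EOT token's log-probability contributes to the running sum that the fallback logic averages. A minimal sketch of the assumed downstream computation (the helper is illustrative, not from this diff):

    /// Illustrative helper: the average log-probability that
    /// decode_with_fallback compares against LOGPROB_THRESHOLD.
    fn avg_logprob(sum_logprob: f64, n_tokens: usize) -> f64 {
        sum_logprob / n_tokens as f64
    }
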
@@ -136,14 +141,17 @@
     fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
         for (i, &t) in TEMPERATURES.iter().enumerate() {
-            let dr: DecodingResult = self.decode(segment, t)?;
+            let dr: Result<DecodingResult> = self.decode(segment, t);
             if i == TEMPERATURES.len() - 1 {
-                return Ok(dr);
+                return dr;
             }
-            let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                || dr.avg_logprob < LOGPROB_THRESHOLD;
-            if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
-                return Ok(dr);
+            // On errors, we try again with a different temperature.
+            if let Ok(dr) = dr {
+                let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
+                    || dr.avg_logprob < LOGPROB_THRESHOLD;
+                if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    return Ok(dr);
+                }
             }
         }
         unreachable!()
     }
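
For context on the temperatures iterated here: a temperature t > 0 is typically applied by flattening the logits before sampling, while t == 0 keeps plain argmax decoding. A minimal sketch of that scaling with candle (a hypothetical helper, not part of this commit):

    use anyhow::Result;
    use candle::Tensor;

    // Hypothetical helper: scale the logits by 1/t before sampling; a zero
    // temperature leaves them untouched so the caller can argmax instead.
    fn apply_temperature(logits: &Tensor, t: f64) -> Result<Tensor> {
        if t > 0. {
            // affine computes logits * (1. / t) + 0. element-wise.
            Ok(logits.affine(1. / t, 0.)?)
        } else {
            Ok(logits.clone())
        }
    }
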
@@ -195,7 +203,7 @@ fn main() -> Result<()> {
             duration: segment_duration,
             dr,
         };
-        println!("{seek} {segment:?}");
+        println!("{seek}: {segment:?}");
         segments.push(segment)
     }
     Ok(())

@@ -1,3 +1,5 @@
+// We use anyhow rather than candle errors as it provides better support for getting the backtrace
+// back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use std::collections::HashMap;
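
As a usage note on that comment: anyhow captures a backtrace at the point an error is created once RUST_LIB_BACKTRACE=1 (or RUST_BACKTRACE=1) is set, which is what makes failures deep in the loading code traceable. A small illustrative sketch (the function and path handling are hypothetical):

    use anyhow::{Context, Result};

    fn load_weights(path: &str) -> Result<Vec<u8>> {
        // With RUST_LIB_BACKTRACE=1, the error returned here carries a
        // backtrace pointing at this call site.
        std::fs::read(path).with_context(|| format!("failed to read {path}"))
    }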