Clean the decode loop of the whisper example.

laurent
2023-07-05 08:37:26 +01:00
parent fbdabf0325
commit 9694e35db0
2 changed files with 24 additions and 14 deletions

View File

@@ -2,7 +2,8 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
 // - kv-cache support?
-// - language detection?
+// - Language detection?
+// - Batch size greater than 1.
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
@@ -31,6 +32,10 @@ const LOGPROB_THRESHOLD: f64 = -1.0;
 const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
 const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 
+// Tokenizer dependent bits.
+const SOT_TOKEN: u32 = 50257;
+const EOT_TOKEN: u32 = 50256;
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
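The hardcoded ids above match the English-only Whisper vocabulary, where <|startoftranscript|> is 50257 and <|endoftext|> is 50256. A minimal sketch of resolving them from the tokenizer at startup instead, assuming only the tokenizers crate's token_to_id; the helper name is hypothetical:

// Hypothetical helper: look the special tokens up instead of hardcoding ids.
use anyhow::{anyhow, Result};
use tokenizers::Tokenizer;

fn special_token_id(tokenizer: &Tokenizer, token: &str) -> Result<u32> {
    // token_to_id returns None when the token is not in the vocabulary.
    tokenizer
        .token_to_id(token)
        .ok_or_else(|| anyhow!("no token id found for {token}"))
}

// let sot_token = special_token_id(&tokenizer, "<|startoftranscript|>")?;
// let eot_token = special_token_id(&tokenizer, "<|endoftext|>")?;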
@@ -83,11 +88,12 @@ impl Decode {
         let sample_len = model.config.n_text_ctx / 2;
         let mut sum_logprob = 0f64;
         let no_speech_prob = f64::NAN;
-        // TODO: 50257 is the start of transcipt token, be more principled about get initial tokens
-        let mut tokens: Vec<u32> = vec![50257];
+        let mut tokens = vec![SOT_TOKEN];
         for _i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
-            // Insert a batch dim.
+            // The model expects a batch dim but this inference loop does not handle
+            // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
             let logits = model.decoder.forward(&tokens_t, &audio_features)?;
             let logits = logits.squeeze(0)?;
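To make the new batch-dim comment concrete, here is a self-contained sketch of the shape round-trip, assuming the candle Tensor::new / unsqueeze / squeeze calls used above (in the real loop, squeeze(0) is applied to the decoder's logits rather than to the tokens):

use candle::{Device, Tensor};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // A 1D sequence of token ids, shape (seq_len,).
    let tokens: Vec<u32> = vec![50257, 440, 7];
    let tokens_t = Tensor::new(tokens.as_slice(), &device)?;
    assert_eq!(tokens_t.dims(), &[3]);
    // unsqueeze(0) prepends the batch dimension the model expects: (1, seq_len).
    let batched = tokens_t.unsqueeze(0)?;
    assert_eq!(batched.dims(), &[1, 3]);
    // squeeze(0) removes the size-1 dimension again, back to (seq_len,).
    let unbatched = batched.squeeze(0)?;
    assert_eq!(unbatched.dims(), &[3]);
    Ok(())
}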
@@ -112,11 +118,10 @@
                 .softmax(logits.rank() - 1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            sum_logprob += prob.ln();
-            // 50256 is the eot token, TODO: parameterize this.
-            if next_token == 50256 || tokens.len() > model.config.n_text_ctx {
+            if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
                 break;
             }
+            sum_logprob += prob.ln();
         }
         let text = self
             .tokenizer
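Moving sum_logprob += prob.ln(); below the break means the end-of-transcript token's probability is no longer counted. A toy illustration of the bookkeeping, with made-up probabilities (the exact denominator used for the average lives outside this hunk):

fn main() {
    // Per-token probabilities of the sampled tokens, toy values.
    let probs = [0.9f64, 0.8, 0.5];
    let sum_logprob: f64 = probs.iter().map(|p| p.ln()).sum();
    let avg_logprob = sum_logprob / probs.len() as f64;
    println!("sum_logprob={sum_logprob:.3} avg_logprob={avg_logprob:.3}");
    // Prints sum_logprob=-1.022 avg_logprob=-0.341; the average sits above
    // LOGPROB_THRESHOLD (-1.0), so this segment would not trigger a fallback.
}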
@@ -136,14 +141,17 @@ impl Decode {
     fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
         for (i, &t) in TEMPERATURES.iter().enumerate() {
-            let dr: DecodingResult = self.decode(segment, t)?;
+            let dr: Result<DecodingResult> = self.decode(segment, t);
             if i == TEMPERATURES.len() - 1 {
-                return Ok(dr);
+                return dr;
             }
-            let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                || dr.avg_logprob < LOGPROB_THRESHOLD;
-            if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
-                return Ok(dr);
+            // On errors, we try again with a different temperature.
+            if let Ok(dr) = dr {
+                let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
+                    || dr.avg_logprob < LOGPROB_THRESHOLD;
+                if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    return Ok(dr);
+                }
             }
         }
         unreachable!()
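The new control flow retries on both decoding errors and low-quality results, and returns whatever the last temperature produced. A generic sketch of that pattern, with hypothetical names not taken from the example:

use anyhow::Result;

// Try `run` at each temperature; retry on error or when `acceptable` rejects
// the result, and return the final attempt's outcome unconditionally.
fn with_fallback<T>(
    temperatures: &[f64],
    mut run: impl FnMut(f64) -> Result<T>,
    acceptable: impl Fn(&T) -> bool,
) -> Result<T> {
    for (i, &t) in temperatures.iter().enumerate() {
        let res = run(t);
        if i == temperatures.len() - 1 {
            return res;
        }
        if let Ok(v) = res {
            if acceptable(&v) {
                return Ok(v);
            }
        }
    }
    // Like the unreachable!() above, this is only hit if `temperatures` is empty.
    unreachable!("temperatures must be non-empty")
}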
@@ -195,7 +203,7 @@ fn main() -> Result<()> {
                 duration: segment_duration,
                 dr,
             };
-            println!("{seek} {segment:?}");
+            println!("{seek}: {segment:?}");
             segments.push(segment)
         }
         Ok(())

View File

@@ -1,3 +1,5 @@
+// We use anyhow rather than candle errors as it provides better support for getting the backtrace
+// back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use std::collections::HashMap;
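A small sketch of the rationale in that new comment, assuming nothing beyond the anyhow crate itself: errors converted into anyhow::Error capture a backtrace that is displayed when the process runs with RUST_LIB_BACKTRACE=1 (or RUST_BACKTRACE=1). The file name here is a hypothetical example:

use anyhow::{Context, Result};

fn load_weights(path: &str) -> Result<Vec<u8>> {
    // with_context converts the std::io::Error into an anyhow::Error and
    // attaches the path to the error message.
    std::fs::read(path).with_context(|| format!("failed to read {path}"))
}

fn main() -> Result<()> {
    // Run as `RUST_LIB_BACKTRACE=1 cargo run` to see the captured backtrace
    // printed along with the error on failure.
    let _weights = load_weights("model.safetensors")?;
    Ok(())
}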