From 9694e35db08b70ca3a9f60b54cf9f0701425edbd Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 5 Jul 2023 08:37:26 +0100
Subject: [PATCH] Clean the decode loop of the whisper example.

---
 candle-examples/examples/whisper/main.rs  | 36 ++++++++++++++---------
 candle-examples/examples/whisper/model.rs |  2 ++
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 36caa183..1446e067 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -2,7 +2,8 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
 // - kv-cache support?
-// - language detection?
+// - Language detection?
+// - Batch size greater than 1.
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
@@ -31,6 +32,10 @@ const LOGPROB_THRESHOLD: f64 = -1.0;
 const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
 const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 
+// Tokenizer-dependent bits.
+const SOT_TOKEN: u32 = 50257;
+const EOT_TOKEN: u32 = 50256;
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -83,11 +88,12 @@ impl Decode {
         let sample_len = model.config.n_text_ctx / 2;
         let mut sum_logprob = 0f64;
         let no_speech_prob = f64::NAN;
-        // TODO: 50257 is the start of transcipt token, be more principled about get initial tokens
-        let mut tokens: Vec<u32> = vec![50257];
+        let mut tokens = vec![SOT_TOKEN];
         for _i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
-            // Insert a batch dim.
+
+            // The model expects a batch dim but this inference loop does not
+            // handle it, so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
             let logits = model.decoder.forward(&tokens_t, &audio_features)?;
             let logits = logits.squeeze(0)?;
@@ -112,11 +118,10 @@
                 .softmax(logits.rank() - 1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            sum_logprob += prob.ln();
-            // 50256 is the eot token, TODO: parameterize this.
-            if next_token == 50256 || tokens.len() > model.config.n_text_ctx {
+            if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
                 break;
             }
+            sum_logprob += prob.ln();
         }
         let text = self
             .tokenizer
@@ -136,14 +141,17 @@
 
     fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
         for (i, &t) in TEMPERATURES.iter().enumerate() {
-            let dr: DecodingResult = self.decode(segment, t)?;
+            let dr: Result<DecodingResult> = self.decode(segment, t);
             if i == TEMPERATURES.len() - 1 {
-                return Ok(dr);
+                return dr;
             }
-            let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                || dr.avg_logprob < LOGPROB_THRESHOLD;
-            if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
-                return Ok(dr);
+            // On errors, we try again with a different temperature.
+            if let Ok(dr) = dr {
+                let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
+                    || dr.avg_logprob < LOGPROB_THRESHOLD;
+                if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    return Ok(dr);
+                }
             }
         }
         unreachable!()
@@ -195,7 +203,7 @@ fn main() -> Result<()> {
             duration: segment_duration,
             dr,
         };
-        println!("{seek} {segment:?}");
+        println!("{seek}: {segment:?}");
         segments.push(segment)
     }
     Ok(())
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 94607934..53ee6a90 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -1,3 +1,5 @@
+// We use anyhow rather than candle errors as it provides better support for
+// getting the backtrace when running with RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use std::collections::HashMap;
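
Note on the hard-coded token ids: SOT_TOKEN and EOT_TOKEN above match the
GPT-2 based vocabulary used by the English-only Whisper checkpoints. A
follow-up could resolve them from the tokenizer at runtime rather than
hard-coding them. Below is a minimal sketch of that idea, assuming the
tokenizers crate's token_to_id API and the standard Whisper special-token
names; the token_id helper and the tokenizer.json path are hypothetical and
not part of this patch:

    use anyhow::{anyhow, Error as E, Result};
    use tokenizers::Tokenizer;

    // Look up a special token id by name, failing loudly when the tokenizer
    // does not define it. (Hypothetical helper, not part of this patch.)
    fn token_id(tokenizer: &Tokenizer, token: &str) -> Result<u32> {
        tokenizer
            .token_to_id(token)
            .ok_or_else(|| anyhow!("no token id found for {token}"))
    }

    fn main() -> Result<()> {
        let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(E::msg)?;
        // These lookups would replace the SOT_TOKEN / EOT_TOKEN consts.
        let sot_token = token_id(&tokenizer, "<|startoftranscript|>")?;
        let eot_token = token_id(&tokenizer, "<|endoftext|>")?;
        println!("sot: {sot_token}, eot: {eot_token}");
        Ok(())
    }

This would keep the decode loop independent of a particular checkpoint's
vocabulary, at the cost of requiring the tokenizer file at startup.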