From cd230d26fecb2ba69352a125d8ba1a4e75f3e6d1 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 6 Jul 2023 09:13:20 +0100 Subject: [PATCH] Whisper tweaks (#85) * Isolate the decoding bits of the whisper example. * Decode -> Decoder. * Add the suppress tokens filter. * More suppress tokens. --- candle-examples/examples/whisper/main.rs | 107 ++++++++++++++-------- candle-examples/examples/whisper/model.rs | 4 +- 2 files changed, 72 insertions(+), 39 deletions(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 79d47a39..7679f1a2 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -4,6 +4,7 @@ // - kv-cache support? // - Language detection? // - Batch size greater than 1. +// - More token filters (SuppressBlanks, ApplyTimestampRules). use anyhow::{Error as E, Result}; use candle::{DType, Device, Tensor}; @@ -40,6 +41,16 @@ const SOT_TOKEN: u32 = 50257; const EOT_TOKEN: u32 = 50256; const NO_SPEECH_TOKEN: u32 = 50361; const NO_TIMESTAMP_TOKEN: u32 = 50362; +// From the _get_suppress_tokens function + 50362 (no timestamp) +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605 +const SUPPRESS_TOKENS: [u32; 91] = [ + 1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, + 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, + 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, + 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, + 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, + 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362, +]; #[derive(Debug, Clone)] struct DecodingResult { @@ -58,13 +69,33 @@ struct Segment { dr: DecodingResult, } -struct Decode { +struct Decoder { model: Whisper, rng: rand::rngs::StdRng, tokenizer: Tokenizer, + suppress_tokens: Tensor, } -impl Decode { +impl Decoder { + fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result { + let suppress_tokens: Vec = (0..model.config.vocab_size as u32) + .map(|i| { + if SUPPRESS_TOKENS.contains(&i) { + f32::NEG_INFINITY + } else { + 0f32 + } + }) + .collect(); + let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; + Ok(Self { + model, + rng: rand::rngs::StdRng::seed_from_u64(seed), + tokenizer, + suppress_tokens, + }) + } + fn decode(&mut self, mel: &Tensor, t: f64) -> Result { let model = &self.model; let audio_features = model.encoder.forward(mel)?; @@ -93,7 +124,9 @@ impl Decode { } let (seq_len, _) = logits.shape().r2()?; - let logits = logits.get(seq_len - 1)?; + let logits = logits + .get(seq_len - 1)? + .broadcast_add(&self.suppress_tokens)?; let next_token = if t > 0f64 { let prs = (&logits / t)?.softmax(0)?; let logits_v: Vec = prs.to_vec1()?; @@ -156,6 +189,33 @@ impl Decode { } unreachable!() } + + fn run(&mut self, mel: &Tensor) -> Result> { + let (_, _, content_frames) = mel.shape().r3()?; + let mut seek = 0; + let mut segments = vec![]; + let start = std::time::Instant::now(); + while seek < content_frames { + let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let segment_size = usize::min(content_frames - seek, N_FRAMES); + let mel_segment = mel.narrow(2, seek, segment_size)?; + let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let dr = self.decode_with_fallback(&mel_segment)?; + seek += segment_size; + if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { + println!("no speech detected, skipping {seek} {dr:?}"); + continue; + } + let segment = Segment { + start: time_offset, + duration: segment_duration, + dr, + }; + println!("{seek}: {segment:?} : Took {:?}", start.elapsed()); + segments.push(segment) + } + Ok(segments) + } } #[derive(Parser, Debug)] @@ -168,11 +228,13 @@ struct Args { #[arg(long)] model_id: Option, - /// The model to use, check out available models: https://huggingface.co/models?search=whisper + /// The model to use, check out available models: + /// https://huggingface.co/models?search=whisper #[arg(long)] revision: Option, - /// The input to be processed, in wav formats, will default to `jfk.wav` https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav + /// The input to be processed, in wav formats, will default to `jfk.wav` + /// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav #[arg(long)] input: Option, @@ -196,8 +258,6 @@ async fn main() -> Result<()> { } else { Device::new_cuda(0)? }; - let rng = rand::rngs::StdRng::seed_from_u64(args.seed); - let default_model = "openai/whisper-tiny.en".to_string(); let path = std::path::PathBuf::from(default_model.clone()); let default_revision = "refs/pr/15".to_string(); @@ -267,37 +327,10 @@ async fn main() -> Result<()> { let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device); + let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let model = Whisper::load(&vb, config)?; - let mut dc = Decode { - model, - rng, - tokenizer, - }; - - let (_, _, content_frames) = mel.shape().r3()?; - let mut seek = 0; - let mut segments = vec![]; - let start = std::time::Instant::now(); - while seek < content_frames { - let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; - let segment_size = usize::min(content_frames - seek, N_FRAMES); - let mel_segment = mel.narrow(2, seek, segment_size)?; - let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; - let dr = dc.decode_with_fallback(&mel_segment)?; - seek += segment_size; - if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { - println!("no speech detected, skipping {seek} {dr:?}"); - continue; - } - let segment = Segment { - start: time_offset, - duration: segment_duration, - dr, - }; - println!("{seek}: {segment:?} : Took {:?}", start.elapsed()); - segments.push(segment) - } + let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?; + dc.run(&mel)?; Ok(()) } diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index 3de150d6..e589e231 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -15,7 +15,7 @@ impl<'a> VarBuilder<'a> { pub fn from_safetensors( safetensors: Vec>, dtype: DType, - device: Device, + device: &Device, ) -> Self { let mut routing = HashMap::new(); for (index, sf) in safetensors.iter().enumerate() { @@ -25,7 +25,7 @@ impl<'a> VarBuilder<'a> { } Self { safetensors: Some((routing, safetensors)), - device, + device: device.clone(), dtype, } }