Whisper tweaks (#85)

* Isolate the decoding bits of the whisper example.

* Decode -> Decoder.

* Add the suppress tokens filter.

* More suppress tokens.
This commit is contained in:
Laurent Mazare
2023-07-06 09:13:20 +01:00
committed by GitHub
parent be9b493a75
commit cd230d26fe
2 changed files with 72 additions and 39 deletions

View File

@ -4,6 +4,7 @@
// - kv-cache support?
// - Language detection?
// - Batch size greater than 1.
// - More token filters (SuppressBlanks, ApplyTimestampRules).
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
@ -40,6 +41,16 @@ const SOT_TOKEN: u32 = 50257;
const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
const NO_TIMESTAMP_TOKEN: u32 = 50362;
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
// Token ids that must never be sampled; Decoder::new turns this list into a
// vocab-sized tensor of 0 / -inf that is broadcast-added to the logits, which
// forces the softmax probability of each listed token to zero.
const SUPPRESS_TOKENS: [u32; 91] = [
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
#[derive(Debug, Clone)]
struct DecodingResult {
@ -58,13 +69,33 @@ struct Segment {
dr: DecodingResult,
}
struct Decode {
struct Decoder {
model: Whisper,
rng: rand::rngs::StdRng,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
}
impl Decode {
impl Decoder {
/// Builds a `Decoder`, precomputing the suppress-tokens mask on `device`.
///
/// The mask is a dense `vocab_size`-length tensor holding `-inf` at every
/// index listed in `SUPPRESS_TOKENS` and `0` elsewhere; adding it to the
/// logits zeroes out the suppressed tokens after softmax. `seed` makes the
/// sampling RNG reproducible.
fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> {
// One f32 per vocab entry: -inf for suppressed ids, 0 otherwise.
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
.map(|i| {
if SUPPRESS_TOKENS.contains(&i) {
f32::NEG_INFINITY
} else {
0f32
}
})
.collect();
// Upload once; decode() reuses it via broadcast_add on every step.
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
tokenizer,
suppress_tokens,
})
}
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &self.model;
let audio_features = model.encoder.forward(mel)?;
@ -93,7 +124,9 @@ impl Decode {
}
let (seq_len, _) = logits.shape().r2()?;
let logits = logits.get(seq_len - 1)?;
let logits = logits
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = (&logits / t)?.softmax(0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
@ -156,6 +189,33 @@ impl Decode {
}
unreachable!()
}
/// Transcribes a full mel spectrogram by sliding a window of at most
/// `N_FRAMES` frames over it, decoding each chunk independently.
///
/// Returns one `Segment` (start offset, duration, decoding result) per
/// decoded window; windows judged to be non-speech are skipped.
/// `mel` is expected to be rank 3 with frames on the last axis (enforced
/// by `r3()` and the `narrow(2, ..)` below).
fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
let (_, _, content_frames) = mel.shape().r3()?;
// `seek` is the index of the first not-yet-decoded frame.
let mut seek = 0;
let mut segments = vec![];
let start = std::time::Instant::now();
while seek < content_frames {
// Convert the frame offset to seconds for the segment timestamp.
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
// Last window may be shorter than N_FRAMES.
let segment_size = usize::min(content_frames - seek, N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
// Skip windows that look like silence: high no-speech probability
// combined with a low average log-prob of the decoded tokens.
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}");
continue;
}
let segment = Segment {
start: time_offset,
duration: segment_duration,
dr,
};
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
segments.push(segment)
}
Ok(segments)
}
}
#[derive(Parser, Debug)]
@ -168,11 +228,13 @@ struct Args {
#[arg(long)]
model_id: Option<String>,
/// The model to use, check out available models: https://huggingface.co/models?search=whisper
/// The model to use, check out available models:
/// https://huggingface.co/models?search=whisper
#[arg(long)]
revision: Option<String>,
/// The input to be processed, in wav formats, will default to `jfk.wav` https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
/// The input to be processed, in wav formats, will default to `jfk.wav`
/// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
#[arg(long)]
input: Option<String>,
@ -196,8 +258,6 @@ async fn main() -> Result<()> {
} else {
Device::new_cuda(0)?
};
let rng = rand::rngs::StdRng::seed_from_u64(args.seed);
let default_model = "openai/whisper-tiny.en".to_string();
let path = std::path::PathBuf::from(default_model.clone());
let default_revision = "refs/pr/15".to_string();
@ -267,37 +327,10 @@ async fn main() -> Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let model = Whisper::load(&vb, config)?;
let mut dc = Decode {
model,
rng,
tokenizer,
};
let (_, _, content_frames) = mel.shape().r3()?;
let mut seek = 0;
let mut segments = vec![];
let start = std::time::Instant::now();
while seek < content_frames {
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let dr = dc.decode_with_fallback(&mel_segment)?;
seek += segment_size;
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}");
continue;
}
let segment = Segment {
start: time_offset,
duration: segment_duration,
dr,
};
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
segments.push(segment)
}
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
dc.run(&mel)?;
Ok(())
}