Improve the timestamps support in whisper (#539)

* Timestamp support for whisper.

* Properly display the timestamps.

* Bugfix for the timestamp units.
This commit is contained in:
Laurent Mazare
2023-08-21 12:26:59 +01:00
committed by GitHub
parent e3b71851e6
commit cc2d6cf2e0

View File

@ -70,6 +70,7 @@ struct Decoder {
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
verbose: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
@ -82,6 +83,7 @@ struct Decoder {
}
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
model: Whisper,
tokenizer: Tokenizer,
@ -90,10 +92,16 @@ impl Decoder {
language_token: Option<u32>,
task: Option<Task>,
timestamps: bool,
verbose: bool,
) -> Result<Self> {
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
.map(|i| {
if model.config.suppress_tokens.contains(&i) {
if model.config.suppress_tokens.contains(&i)
|| timestamps && i == no_timestamps_token
{
f32::NEG_INFINITY
} else {
0f32
@ -104,7 +112,6 @@ impl Decoder {
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self {
@ -113,6 +120,7 @@ impl Decoder {
tokenizer,
task,
timestamps,
verbose,
suppress_tokens,
sot_token,
transcribe_token,
@ -127,7 +135,9 @@ impl Decoder {
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder.forward(mel, true)?;
println!("audio features: {:?}", audio_features.dims());
if self.verbose {
println!("audio features: {:?}", audio_features.dims());
}
let sample_len = model.config.max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
@ -168,6 +178,13 @@ impl Decoder {
.final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
// ApplyTimestampRules, i.e.:
// - Timestamps come in pairs, except before EOT.
// - Timestamps should be non-decreasing.
// - If the sum of the probabilities of timestamps is higher than any other tokens,
// only consider timestamps when sampling.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
@ -249,7 +266,55 @@ impl Decoder {
duration: segment_duration,
dr,
};
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
if self.timestamps {
println!(
"{:.1}s -- {:.1}s",
segment.start,
segment.start + segment.duration,
);
let mut tokens_to_decode = vec![];
let mut prev_timestamp_s = 0f32;
for &token in segment.dr.tokens.iter() {
if token == self.sot_token || token == self.eot_token {
continue;
}
// The no_timestamp_token is the last before the timestamp ones.
if token > self.no_timestamps_token {
let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
if !tokens_to_decode.is_empty() {
let text = self
.tokenizer
.decode(&tokens_to_decode, true)
.map_err(E::msg)?;
println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
tokens_to_decode.clear()
}
prev_timestamp_s = timestamp_s;
} else {
tokens_to_decode.push(token)
}
}
if !tokens_to_decode.is_empty() {
let text = self
.tokenizer
.decode(&tokens_to_decode, true)
.map_err(E::msg)?;
if !text.is_empty() {
println!(" {:.1}s-...: {}", prev_timestamp_s, text);
}
tokens_to_decode.clear()
}
} else {
println!(
"{:.1}s -- {:.1}s: {}",
segment.start,
segment.start + segment.duration,
segment.dr.text,
)
}
if self.verbose {
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
}
segments.push(segment)
}
Ok(segments)
@ -357,6 +422,10 @@ struct Args {
/// Timestamps mode, this is not fully implemented yet.
#[arg(long)]
timestamps: bool,
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
}
fn main() -> Result<()> {
@ -466,6 +535,7 @@ fn main() -> Result<()> {
language_token,
args.task,
args.timestamps,
args.verbose,
)?;
dc.run(&mel)?;
Ok(())