mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Improve the timestamps support in whisper (#539)
* Timestamp support for whisper. * Properly display the timestamps. * Bugfix for the timestamp units.
This commit is contained in:
@ -70,6 +70,7 @@ struct Decoder {
|
||||
rng: rand::rngs::StdRng,
|
||||
task: Option<Task>,
|
||||
timestamps: bool,
|
||||
verbose: bool,
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Tensor,
|
||||
sot_token: u32,
|
||||
@ -82,6 +83,7 @@ struct Decoder {
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Whisper,
|
||||
tokenizer: Tokenizer,
|
||||
@ -90,10 +92,16 @@ impl Decoder {
|
||||
language_token: Option<u32>,
|
||||
task: Option<Task>,
|
||||
timestamps: bool,
|
||||
verbose: bool,
|
||||
) -> Result<Self> {
|
||||
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
||||
// Suppress the notimestamps token when in timestamps mode.
|
||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||
.map(|i| {
|
||||
if model.config.suppress_tokens.contains(&i) {
|
||||
if model.config.suppress_tokens.contains(&i)
|
||||
|| timestamps && i == no_timestamps_token
|
||||
{
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0f32
|
||||
@ -104,7 +112,6 @@ impl Decoder {
|
||||
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
|
||||
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||
Ok(Self {
|
||||
@ -113,6 +120,7 @@ impl Decoder {
|
||||
tokenizer,
|
||||
task,
|
||||
timestamps,
|
||||
verbose,
|
||||
suppress_tokens,
|
||||
sot_token,
|
||||
transcribe_token,
|
||||
@ -127,7 +135,9 @@ impl Decoder {
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
if self.verbose {
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
}
|
||||
let sample_len = model.config.max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
@ -168,6 +178,13 @@ impl Decoder {
|
||||
.final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
// TODO: Besides suppress tokens, we should apply the heuristics from
|
||||
// ApplyTimestampRules, i.e.:
|
||||
// - Timestamps come in pairs, except before EOT.
|
||||
// - Timestamps should be non-decreasing.
|
||||
// - If the sum of the probabilities of timestamps is higher than any other tokens,
|
||||
// only consider timestamps when sampling.
|
||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
|
||||
let logits = logits.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
@ -249,7 +266,55 @@ impl Decoder {
|
||||
duration: segment_duration,
|
||||
dr,
|
||||
};
|
||||
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
|
||||
if self.timestamps {
|
||||
println!(
|
||||
"{:.1}s -- {:.1}s",
|
||||
segment.start,
|
||||
segment.start + segment.duration,
|
||||
);
|
||||
let mut tokens_to_decode = vec![];
|
||||
let mut prev_timestamp_s = 0f32;
|
||||
for &token in segment.dr.tokens.iter() {
|
||||
if token == self.sot_token || token == self.eot_token {
|
||||
continue;
|
||||
}
|
||||
// The no_timestamp_token is the last before the timestamp ones.
|
||||
if token > self.no_timestamps_token {
|
||||
let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
|
||||
if !tokens_to_decode.is_empty() {
|
||||
let text = self
|
||||
.tokenizer
|
||||
.decode(&tokens_to_decode, true)
|
||||
.map_err(E::msg)?;
|
||||
println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
|
||||
tokens_to_decode.clear()
|
||||
}
|
||||
prev_timestamp_s = timestamp_s;
|
||||
} else {
|
||||
tokens_to_decode.push(token)
|
||||
}
|
||||
}
|
||||
if !tokens_to_decode.is_empty() {
|
||||
let text = self
|
||||
.tokenizer
|
||||
.decode(&tokens_to_decode, true)
|
||||
.map_err(E::msg)?;
|
||||
if !text.is_empty() {
|
||||
println!(" {:.1}s-...: {}", prev_timestamp_s, text);
|
||||
}
|
||||
tokens_to_decode.clear()
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"{:.1}s -- {:.1}s: {}",
|
||||
segment.start,
|
||||
segment.start + segment.duration,
|
||||
segment.dr.text,
|
||||
)
|
||||
}
|
||||
if self.verbose {
|
||||
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
|
||||
}
|
||||
segments.push(segment)
|
||||
}
|
||||
Ok(segments)
|
||||
@ -357,6 +422,10 @@ struct Args {
|
||||
/// Timestamps mode, this is not fully implemented yet.
|
||||
#[arg(long)]
|
||||
timestamps: bool,
|
||||
|
||||
/// Print the full DecodingResult structure rather than just the text.
|
||||
#[arg(long)]
|
||||
verbose: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -466,6 +535,7 @@ fn main() -> Result<()> {
|
||||
language_token,
|
||||
args.task,
|
||||
args.timestamps,
|
||||
args.verbose,
|
||||
)?;
|
||||
dc.run(&mel)?;
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user