More multilingual support for whisper. (#419)

* More multilingual support for whisper.

* Use the language token appropriately.
This commit is contained in:
Laurent Mazare
2023-08-12 16:32:52 +02:00
committed by GitHub
parent 0c3f109faa
commit 0741ebbd51
3 changed files with 47 additions and 23 deletions

View File

@ -38,18 +38,9 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits. // Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>"; const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const EOT_TOKEN: &str = "<|endoftext|>"; const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
const SUPPRESS_TOKENS: [u32; 91] = [
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -76,15 +67,23 @@ struct Decoder {
tokenizer: Tokenizer, tokenizer: Tokenizer,
suppress_tokens: Tensor, suppress_tokens: Tensor,
sot_token: u32, sot_token: u32,
transcribe_token: u32,
eot_token: u32, eot_token: u32,
no_speech_token: u32, no_speech_token: u32,
language_token: Option<u32>,
} }
impl Decoder { impl Decoder {
fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> { fn new(
model: Whisper,
tokenizer: Tokenizer,
seed: u64,
device: &Device,
language_token: Option<u32>,
) -> Result<Self> {
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32) let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
.map(|i| { .map(|i| {
if SUPPRESS_TOKENS.contains(&i) { if model.config.suppress_tokens.contains(&i) {
f32::NEG_INFINITY f32::NEG_INFINITY
} else { } else {
0f32 0f32
@ -93,6 +92,7 @@ impl Decoder {
.collect(); .collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?; let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?; let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?; let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self { Ok(Self {
@ -101,8 +101,10 @@ impl Decoder {
tokenizer, tokenizer,
suppress_tokens, suppress_tokens,
sot_token, sot_token,
transcribe_token,
eot_token, eot_token,
no_speech_token, no_speech_token,
language_token,
}) })
} }
@ -114,6 +116,10 @@ impl Decoder {
let mut sum_logprob = 0f64; let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN; let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token]; let mut tokens = vec![self.sot_token];
if let Some(language_token) = self.language_token {
tokens.push(language_token)
}
tokens.push(self.transcribe_token);
for i in 0..sample_len { for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?; let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
@ -236,23 +242,29 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
enum WhichModel { enum WhichModel {
Tiny, Tiny,
TinyEn, TinyEn,
Base,
BaseEn,
SmallEn, SmallEn,
MediumEn, MediumEn,
LargeV2,
} }
impl WhichModel { impl WhichModel {
fn is_multilingual(&self) -> bool { fn is_multilingual(&self) -> bool {
match self { match self {
Self::Tiny => true, Self::Tiny | Self::Base | Self::LargeV2 => true,
Self::TinyEn | Self::SmallEn | Self::MediumEn => false, Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
} }
} }
fn model_and_revision(&self) -> (&'static str, &'static str) { fn model_and_revision(&self) -> (&'static str, &'static str) {
match self { match self {
Self::Tiny => ("openai/whisper-tiny", "main"), Self::Tiny => ("openai/whisper-tiny", "main"),
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"), Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
Self::Base => ("openai/whisper-base", "refs/pr/22"),
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"), Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"), Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
} }
} }
} }
@ -379,10 +391,12 @@ fn main() -> Result<()> {
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let model = Whisper::load(&vb, config)?; let model = Whisper::load(&vb, config)?;
if args.model.is_multilingual() { let language_token = if args.model.is_multilingual() {
multilingual::detect_language(&model, &tokenizer, &mel)? Some(multilingual::detect_language(&model, &tokenizer, &mel)?)
} } else {
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?; None
};
let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
dc.run(&mel)?; dc.run(&mel)?;
Ok(()) Ok(())
} }

View File

@ -16,11 +16,21 @@ pub struct Config {
// pub n_text_state: usize, // pub n_text_state: usize,
pub decoder_attention_heads: usize, // n_text_head pub decoder_attention_heads: usize, // n_text_head
pub decoder_layers: usize, // n_text_layer pub decoder_layers: usize, // n_text_layer
pub suppress_tokens: Vec<u32>,
} }
impl Config { impl Config {
#[allow(dead_code)] #[allow(dead_code)]
pub fn tiny_en() -> Self { pub fn tiny_en() -> Self {
let suppress_tokens = vec![
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93,
357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391,
1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329,
7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306,
16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409,
34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361,
50362,
];
Self { Self {
num_mel_bins: 80, num_mel_bins: 80,
vocab_size: 51864, vocab_size: 51864,
@ -32,6 +42,7 @@ impl Config {
// n_text_state: 384, // n_text_state: 384,
decoder_attention_heads: 6, decoder_attention_heads: 6,
decoder_layers: 4, decoder_layers: 4,
suppress_tokens,
} }
} }
} }

View File

@ -104,7 +104,8 @@ const LANGUAGES: [(&str, &str); 99] = [
("su", "sundanese"), ("su", "sundanese"),
]; ];
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<()> { /// Returns the token id for the selected language.
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
let device = mel.device(); let device = mel.device();
let language_token_ids = LANGUAGES let language_token_ids = LANGUAGES
.iter() .iter()
@ -114,14 +115,11 @@ pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) ->
let audio_features = model.encoder.forward(mel)?; let audio_features = model.encoder.forward(mel)?;
let tokens = Tensor::new(&[[sot_token]], device)?; let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
println!("{tokens}");
println!("{audio_features}");
let logits = model let logits = model
.decoder .decoder
.forward(&tokens, &audio_features)? .forward(&tokens, &audio_features)?
.i(0)? .i(0)?
.i(0)?; .i(0)?;
println!("{logits}");
let logits = logits.index_select(&language_token_ids, 0)?; let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?; let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?; let probs = probs.to_vec1::<f32>()?;
@ -130,5 +128,6 @@ pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) ->
for ((_, language), p) in probs.iter().take(5) { for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}") println!("{language}: {p}")
} }
Ok(()) let language = crate::token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
Ok(language)
} }