Basic multilingual support for whisper (#417)

* Multi-lingual support for whisper.

* Avoid hardcoding the token names.

* More multi-lingual support.

* Remove the todo.
This commit is contained in:
Laurent Mazare
2023-08-12 12:23:04 +02:00
committed by GitHub
parent 2ba6b2826f
commit 0c3f109faa
2 changed files with 175 additions and 13 deletions

View File

@ -1,7 +1,6 @@
// https://github.com/openai/whisper/blob/main/whisper/model.py // https://github.com/openai/whisper/blob/main/whisper/model.py
// TODO: // TODO:
// - kv-cache support? // - kv-cache support?
// - Language detection?
// - Batch size greater than 1. // - Batch size greater than 1.
// - More token filters (SuppressBlanks, ApplyTimestampRules). // - More token filters (SuppressBlanks, ApplyTimestampRules).
@ -19,6 +18,7 @@ use tokenizers::Tokenizer;
mod audio; mod audio;
mod model; mod model;
use model::{Config, Whisper}; use model::{Config, Whisper};
mod multilingual;
const DTYPE: DType = DType::F32; const DTYPE: DType = DType::F32;
@ -37,9 +37,9 @@ const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits. // Tokenizer dependent bits.
const SOT_TOKEN: u32 = 50257; const SOT_TOKEN: &str = "<|startoftranscript|>";
const EOT_TOKEN: u32 = 50256; const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: u32 = 50361; const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
// From the _get_suppress_tokens function + 50362 (no timestamp) // From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
const SUPPRESS_TOKENS: [u32; 91] = [ const SUPPRESS_TOKENS: [u32; 91] = [
@ -75,6 +75,9 @@ struct Decoder {
rng: rand::rngs::StdRng, rng: rand::rngs::StdRng,
tokenizer: Tokenizer, tokenizer: Tokenizer,
suppress_tokens: Tensor, suppress_tokens: Tensor,
sot_token: u32,
eot_token: u32,
no_speech_token: u32,
} }
impl Decoder { impl Decoder {
@ -89,11 +92,17 @@ impl Decoder {
}) })
.collect(); .collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self { Ok(Self {
model, model,
rng: rand::rngs::StdRng::seed_from_u64(seed), rng: rand::rngs::StdRng::seed_from_u64(seed),
tokenizer, tokenizer,
suppress_tokens, suppress_tokens,
sot_token,
eot_token,
no_speech_token,
}) })
} }
@ -104,7 +113,7 @@ impl Decoder {
let sample_len = model.config.max_target_positions / 2; let sample_len = model.config.max_target_positions / 2;
let mut sum_logprob = 0f64; let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN; let mut no_speech_prob = f64::NAN;
let mut tokens = vec![SOT_TOKEN]; let mut tokens = vec![self.sot_token];
for i in 0..sample_len { for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?; let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
@ -118,7 +127,7 @@ impl Decoder {
// token logits and the probability for the according token. // token logits and the probability for the according token.
if i == 0 { if i == 0 {
no_speech_prob = softmax(&logits.get(0)?, 0)? no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)? .get(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64; .to_scalar::<f32>()? as f64;
} }
@ -144,7 +153,7 @@ impl Decoder {
let prob = softmax(&logits, candle::D::Minus1)? let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)? .get(next_token as usize)?
.to_scalar::<f32>()? as f64; .to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions { if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
break; break;
} }
sum_logprob += prob.ln(); sum_logprob += prob.ln();
@ -216,19 +225,34 @@ impl Decoder {
} }
} }
/// Resolves a special token string (e.g. `<|startoftranscript|>`) to its id
/// in the tokenizer vocabulary, erroring out when the token is unknown.
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
    if let Some(id) = tokenizer.token_to_id(token) {
        Ok(id)
    } else {
        candle::bail!("no token-id for {token}")
    }
}
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel { enum WhichModel {
Tiny, Tiny,
Small, TinyEn,
Medium, SmallEn,
MediumEn,
} }
impl WhichModel { impl WhichModel {
/// Whether the selected checkpoint supports multiple languages; the
/// `*-en` variants are English-only, so only `Tiny` qualifies here.
fn is_multilingual(&self) -> bool {
    matches!(self, Self::Tiny)
}
fn model_and_revision(&self) -> (&'static str, &'static str) { fn model_and_revision(&self) -> (&'static str, &'static str) {
match self { match self {
Self::Tiny => ("openai/whisper-tiny.en", "refs/pr/15"), Self::Tiny => ("openai/whisper-tiny", "main"),
Self::Small => ("openai/whisper-small.en", "refs/pr/10"), Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
Self::Medium => ("openai/whisper-medium.en", "refs/pr/11"), Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
} }
} }
} }
@ -249,7 +273,7 @@ struct Args {
revision: Option<String>, revision: Option<String>,
/// The model to be used, can be tiny, small, medium. /// The model to be used, can be tiny, tiny-en, small-en, medium-en.
#[arg(long, default_value = "tiny")] #[arg(long, default_value = "tiny-en")]
model: WhichModel, model: WhichModel,
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively /// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
@ -354,6 +378,10 @@ fn main() -> Result<()> {
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let model = Whisper::load(&vb, config)?; let model = Whisper::load(&vb, config)?;
if args.model.is_multilingual() {
multilingual::detect_language(&model, &tokenizer, &mel)?
}
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?; let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
dc.run(&mel)?; dc.run(&mel)?;
Ok(()) Ok(())

View File

@ -0,0 +1,134 @@
use crate::Whisper;
use candle::{IndexOp, Result, Tensor, D};
use tokenizers::Tokenizer;
/// The 99 languages understood by multilingual Whisper checkpoints, as
/// `(code, name)` pairs. `code` is the suffix of the tokenizer's special
/// token `<|{code}|>`; the ordering of this table fixes the ordering of the
/// per-language probabilities computed in `detect_language`.
const LANGUAGES: [(&str, &str); 99] = [
    ("en", "english"),
    ("zh", "chinese"),
    ("de", "german"),
    ("es", "spanish"),
    ("ru", "russian"),
    ("ko", "korean"),
    ("fr", "french"),
    ("ja", "japanese"),
    ("pt", "portuguese"),
    ("tr", "turkish"),
    ("pl", "polish"),
    ("ca", "catalan"),
    ("nl", "dutch"),
    ("ar", "arabic"),
    ("sv", "swedish"),
    ("it", "italian"),
    ("id", "indonesian"),
    ("hi", "hindi"),
    ("fi", "finnish"),
    ("vi", "vietnamese"),
    ("he", "hebrew"),
    ("uk", "ukrainian"),
    ("el", "greek"),
    ("ms", "malay"),
    ("cs", "czech"),
    ("ro", "romanian"),
    ("da", "danish"),
    ("hu", "hungarian"),
    ("ta", "tamil"),
    ("no", "norwegian"),
    ("th", "thai"),
    ("ur", "urdu"),
    ("hr", "croatian"),
    ("bg", "bulgarian"),
    ("lt", "lithuanian"),
    ("la", "latin"),
    ("mi", "maori"),
    ("ml", "malayalam"),
    ("cy", "welsh"),
    ("sk", "slovak"),
    ("te", "telugu"),
    ("fa", "persian"),
    ("lv", "latvian"),
    ("bn", "bengali"),
    ("sr", "serbian"),
    ("az", "azerbaijani"),
    ("sl", "slovenian"),
    ("kn", "kannada"),
    ("et", "estonian"),
    ("mk", "macedonian"),
    ("br", "breton"),
    ("eu", "basque"),
    ("is", "icelandic"),
    ("hy", "armenian"),
    ("ne", "nepali"),
    ("mn", "mongolian"),
    ("bs", "bosnian"),
    ("kk", "kazakh"),
    ("sq", "albanian"),
    ("sw", "swahili"),
    ("gl", "galician"),
    ("mr", "marathi"),
    ("pa", "punjabi"),
    ("si", "sinhala"),
    ("km", "khmer"),
    ("sn", "shona"),
    ("yo", "yoruba"),
    ("so", "somali"),
    ("af", "afrikaans"),
    ("oc", "occitan"),
    ("ka", "georgian"),
    ("be", "belarusian"),
    ("tg", "tajik"),
    ("sd", "sindhi"),
    ("gu", "gujarati"),
    ("am", "amharic"),
    ("yi", "yiddish"),
    ("lo", "lao"),
    ("uz", "uzbek"),
    ("fo", "faroese"),
    ("ht", "haitian creole"),
    ("ps", "pashto"),
    ("tk", "turkmen"),
    ("nn", "nynorsk"),
    ("mt", "maltese"),
    ("sa", "sanskrit"),
    ("lb", "luxembourgish"),
    ("my", "myanmar"),
    ("bo", "tibetan"),
    ("tl", "tagalog"),
    ("mg", "malagasy"),
    ("as", "assamese"),
    ("tt", "tatar"),
    ("haw", "hawaiian"),
    ("ln", "lingala"),
    ("ha", "hausa"),
    ("ba", "bashkir"),
    ("jw", "javanese"),
    ("su", "sundanese"),
];
/// Prints the five most probable spoken languages for the given mel
/// spectrogram.
///
/// Runs the encoder on `mel`, then takes the decoder logits for the token
/// following `<|startoftranscript|>`, restricts them to the language tokens
/// (`<|en|>`, `<|zh|>`, ...) and applies a softmax to turn them into
/// per-language probabilities.
///
/// Errors if the tokenizer lacks one of the expected special tokens or if a
/// tensor operation fails.
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<()> {
    let device = mel.device();
    // Map every language code to its `<|code|>` token id, failing early on
    // the first code the tokenizer does not know about.
    let language_token_ids = LANGUAGES
        .iter()
        .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
        .collect::<Result<Vec<_>>>()?;
    let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
    let audio_features = model.encoder.forward(mel)?;
    // Decode a single-step sequence consisting of just the SOT token.
    let tokens = Tensor::new(&[[sot_token]], device)?;
    let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
    // Logits for the first generated position (batch 0, position 0).
    let logits = model
        .decoder
        .forward(&tokens, &audio_features)?
        .i(0)?
        .i(0)?;
    // Keep only the language-token entries, in the order of `LANGUAGES`.
    let logits = logits.index_select(&language_token_ids, 0)?;
    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
    let probs = probs.to_vec1::<f32>()?;
    let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
    // Sort by descending probability and report the top five candidates.
    probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for ((_, language), p) in probs.iter().take(5) {
        println!("{language}: {p}")
    }
    Ok(())
}