More multilingual support for whisper. (#419)

* More multilingual support for whisper.

* Use the language token appropriately.
This commit is contained in:
Laurent Mazare
2023-08-12 16:32:52 +02:00
committed by GitHub
parent 0c3f109faa
commit 0741ebbd51
3 changed files with 47 additions and 23 deletions

View File

@ -38,18 +38,9 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
const SUPPRESS_TOKENS: [u32; 91] = [
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
#[allow(dead_code)]
#[derive(Debug, Clone)]
@ -76,15 +67,23 @@ struct Decoder {
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
eot_token: u32,
no_speech_token: u32,
language_token: Option<u32>,
}
impl Decoder {
fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> {
fn new(
model: Whisper,
tokenizer: Tokenizer,
seed: u64,
device: &Device,
language_token: Option<u32>,
) -> Result<Self> {
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
.map(|i| {
if SUPPRESS_TOKENS.contains(&i) {
if model.config.suppress_tokens.contains(&i) {
f32::NEG_INFINITY
} else {
0f32
@ -93,6 +92,7 @@ impl Decoder {
.collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self {
@ -101,8 +101,10 @@ impl Decoder {
tokenizer,
suppress_tokens,
sot_token,
transcribe_token,
eot_token,
no_speech_token,
language_token,
})
}
@ -114,6 +116,10 @@ impl Decoder {
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
if let Some(language_token) = self.language_token {
tokens.push(language_token)
}
tokens.push(self.transcribe_token);
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
@ -236,23 +242,29 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
enum WhichModel {
Tiny,
TinyEn,
Base,
BaseEn,
SmallEn,
MediumEn,
LargeV2,
}
impl WhichModel {
fn is_multilingual(&self) -> bool {
match self {
Self::Tiny => true,
Self::TinyEn | Self::SmallEn | Self::MediumEn => false,
Self::Tiny | Self::Base | Self::LargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}
fn model_and_revision(&self) -> (&'static str, &'static str) {
match self {
Self::Tiny => ("openai/whisper-tiny", "main"),
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
Self::Base => ("openai/whisper-base", "refs/pr/22"),
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
}
}
}
@ -379,10 +391,12 @@ fn main() -> Result<()> {
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let model = Whisper::load(&vb, config)?;
if args.model.is_multilingual() {
multilingual::detect_language(&model, &tokenizer, &mel)?
}
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
let language_token = if args.model.is_multilingual() {
Some(multilingual::detect_language(&model, &tokenizer, &mel)?)
} else {
None
};
let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
dc.run(&mel)?;
Ok(())
}