mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
More multilingual support for whisper. (#419)
* More multilingual support for whisper. * Use the language token appropriately.
This commit is contained in:
@@ -38,18 +38,9 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 // Tokenizer dependent bits.
 const SOT_TOKEN: &str = "<|startoftranscript|>";
 const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
 const EOT_TOKEN: &str = "<|endoftext|>";
 const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
-// From the _get_suppress_tokens function + 50362 (no timestamp)
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
-const SUPPRESS_TOKENS: [u32; 91] = [
-    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
-    366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
-    1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
-    10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
-    19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
-    47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
-];

 #[allow(dead_code)]
 #[derive(Debug, Clone)]
@@ -76,15 +67,23 @@ struct Decoder {
     tokenizer: Tokenizer,
     suppress_tokens: Tensor,
     sot_token: u32,
     transcribe_token: u32,
     eot_token: u32,
     no_speech_token: u32,
+    language_token: Option<u32>,
 }

 impl Decoder {
-    fn new(model: Whisper, tokenizer: Tokenizer, seed: u64, device: &Device) -> Result<Self> {
+    fn new(
+        model: Whisper,
+        tokenizer: Tokenizer,
+        seed: u64,
+        device: &Device,
+        language_token: Option<u32>,
+    ) -> Result<Self> {
         let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
             .map(|i| {
-                if SUPPRESS_TOKENS.contains(&i) {
+                if model.config.suppress_tokens.contains(&i) {
                     f32::NEG_INFINITY
                 } else {
                     0f32
@@ -93,6 +92,7 @@ impl Decoder {
             .collect();
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
         let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
         let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
         let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
         let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
         Ok(Self {
@@ -101,8 +101,10 @@ impl Decoder {
             tokenizer,
             suppress_tokens,
             sot_token,
             transcribe_token,
             eot_token,
             no_speech_token,
+            language_token,
         })
     }

@@ -114,6 +116,10 @@ impl Decoder {
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![self.sot_token];
+        if let Some(language_token) = self.language_token {
+            tokens.push(language_token)
+        }
+        tokens.push(self.transcribe_token);
         for i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;

@@ -236,23 +242,29 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
 enum WhichModel {
     Tiny,
     TinyEn,
+    Base,
+    BaseEn,
     SmallEn,
     MediumEn,
+    LargeV2,
 }

 impl WhichModel {
     fn is_multilingual(&self) -> bool {
         match self {
-            Self::Tiny => true,
-            Self::TinyEn | Self::SmallEn | Self::MediumEn => false,
+            Self::Tiny | Self::Base | Self::LargeV2 => true,
+            Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
         }
     }
     fn model_and_revision(&self) -> (&'static str, &'static str) {
         match self {
             Self::Tiny => ("openai/whisper-tiny", "main"),
             Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
+            Self::Base => ("openai/whisper-base", "refs/pr/22"),
+            Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
             Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
             Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
+            Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
         }
     }
 }
@@ -379,10 +391,12 @@ fn main() -> Result<()> {
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let model = Whisper::load(&vb, config)?;

-    if args.model.is_multilingual() {
-        multilingual::detect_language(&model, &tokenizer, &mel)?
-    }
-    let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
+    let language_token = if args.model.is_multilingual() {
+        Some(multilingual::detect_language(&model, &tokenizer, &mel)?)
+    } else {
+        None
+    };
+    let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
     dc.run(&mel)?;
     Ok(())
 }
Reference in New Issue
Block a user