mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Basic multilingual support for whisper (#417)
* Multi-lingual support for whisper. * Avoid hardcoding the token names. * More multi-lingual support. * Remove the todo.
This commit is contained in:
@ -1,7 +1,6 @@
|
|||||||
// https://github.com/openai/whisper/blob/main/whisper/model.py
|
// https://github.com/openai/whisper/blob/main/whisper/model.py
|
||||||
// TODO:
|
// TODO:
|
||||||
// - kv-cache support?
|
// - kv-cache support?
|
||||||
// - Language detection?
|
|
||||||
// - Batch size greater than 1.
|
// - Batch size greater than 1.
|
||||||
// - More token filters (SuppressBlanks, ApplyTimestampRules).
|
// - More token filters (SuppressBlanks, ApplyTimestampRules).
|
||||||
|
|
||||||
@ -19,6 +18,7 @@ use tokenizers::Tokenizer;
|
|||||||
mod audio;
|
mod audio;
|
||||||
mod model;
|
mod model;
|
||||||
use model::{Config, Whisper};
|
use model::{Config, Whisper};
|
||||||
|
mod multilingual;
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
@ -37,9 +37,9 @@ const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
|||||||
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||||
|
|
||||||
// Tokenizer dependent bits.
|
// Tokenizer dependent bits.
|
||||||
const SOT_TOKEN: u32 = 50257;
|
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||||
const EOT_TOKEN: u32 = 50256;
|
const EOT_TOKEN: &str = "<|endoftext|>";
|
||||||
const NO_SPEECH_TOKEN: u32 = 50361;
|
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||||
const SUPPRESS_TOKENS: [u32; 91] = [
|
const SUPPRESS_TOKENS: [u32; 91] = [
|
||||||
@ -75,6 +75,9 @@ struct Decoder {
|
|||||||
rng: rand::rngs::StdRng,
|
rng: rand::rngs::StdRng,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
suppress_tokens: Tensor,
|
suppress_tokens: Tensor,
|
||||||
|
sot_token: u32,
|
||||||
|
eot_token: u32,
|
||||||
|
no_speech_token: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Decoder {
|
impl Decoder {
|
||||||
@ -89,11 +92,17 @@ impl Decoder {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||||
|
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||||
|
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||||
|
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model,
|
model,
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
suppress_tokens,
|
suppress_tokens,
|
||||||
|
sot_token,
|
||||||
|
eot_token,
|
||||||
|
no_speech_token,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,7 +113,7 @@ impl Decoder {
|
|||||||
let sample_len = model.config.max_target_positions / 2;
|
let sample_len = model.config.max_target_positions / 2;
|
||||||
let mut sum_logprob = 0f64;
|
let mut sum_logprob = 0f64;
|
||||||
let mut no_speech_prob = f64::NAN;
|
let mut no_speech_prob = f64::NAN;
|
||||||
let mut tokens = vec![SOT_TOKEN];
|
let mut tokens = vec![self.sot_token];
|
||||||
for i in 0..sample_len {
|
for i in 0..sample_len {
|
||||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
||||||
|
|
||||||
@ -118,7 +127,7 @@ impl Decoder {
|
|||||||
// token logits and the probability for the according token.
|
// token logits and the probability for the according token.
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||||
.get(NO_SPEECH_TOKEN as usize)?
|
.get(self.no_speech_token as usize)?
|
||||||
.to_scalar::<f32>()? as f64;
|
.to_scalar::<f32>()? as f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,7 +153,7 @@ impl Decoder {
|
|||||||
let prob = softmax(&logits, candle::D::Minus1)?
|
let prob = softmax(&logits, candle::D::Minus1)?
|
||||||
.get(next_token as usize)?
|
.get(next_token as usize)?
|
||||||
.to_scalar::<f32>()? as f64;
|
.to_scalar::<f32>()? as f64;
|
||||||
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
|
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sum_logprob += prob.ln();
|
sum_logprob += prob.ln();
|
||||||
@ -216,19 +225,34 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolves `token` to its numeric id in `tokenizer`'s vocabulary.
///
/// # Errors
///
/// Returns an error carrying the message `no token-id for {token}` when the
/// tokenizer has no entry for `token`.
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
    if let Some(id) = tokenizer.token_to_id(token) {
        Ok(id)
    } else {
        candle::bail!("no token-id for {token}")
    }
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
enum WhichModel {
|
enum WhichModel {
|
||||||
Tiny,
|
Tiny,
|
||||||
Small,
|
TinyEn,
|
||||||
Medium,
|
SmallEn,
|
||||||
|
MediumEn,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WhichModel {
|
impl WhichModel {
|
||||||
|
fn is_multilingual(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Tiny => true,
|
||||||
|
Self::TinyEn | Self::SmallEn | Self::MediumEn => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
||||||
match self {
|
match self {
|
||||||
Self::Tiny => ("openai/whisper-tiny.en", "refs/pr/15"),
|
Self::Tiny => ("openai/whisper-tiny", "main"),
|
||||||
Self::Small => ("openai/whisper-small.en", "refs/pr/10"),
|
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
|
||||||
Self::Medium => ("openai/whisper-medium.en", "refs/pr/11"),
|
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
|
||||||
|
Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -249,7 +273,7 @@ struct Args {
|
|||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
/// The model to be used, can be tiny, small, medium.
|
/// The model to be used, can be tiny, tiny-en, small-en, medium-en.
|
||||||
#[arg(long, default_value = "tiny")]
|
#[arg(long, default_value = "tiny-en")]
|
||||||
model: WhichModel,
|
model: WhichModel,
|
||||||
|
|
||||||
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
|
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
|
||||||
@ -354,6 +378,10 @@ fn main() -> Result<()> {
|
|||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||||
let model = Whisper::load(&vb, config)?;
|
let model = Whisper::load(&vb, config)?;
|
||||||
|
|
||||||
|
if args.model.is_multilingual() {
|
||||||
|
multilingual::detect_language(&model, &tokenizer, &mel)?
|
||||||
|
}
|
||||||
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
|
let mut dc = Decoder::new(model, tokenizer, args.seed, &device)?;
|
||||||
dc.run(&mel)?;
|
dc.run(&mel)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
134
candle-examples/examples/whisper/multilingual.rs
Normal file
134
candle-examples/examples/whisper/multilingual.rs
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
use crate::Whisper;
|
||||||
|
use candle::{IndexOp, Result, Tensor, D};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
const LANGUAGES: [(&str, &str); 99] = [
|
||||||
|
("en", "english"),
|
||||||
|
("zh", "chinese"),
|
||||||
|
("de", "german"),
|
||||||
|
("es", "spanish"),
|
||||||
|
("ru", "russian"),
|
||||||
|
("ko", "korean"),
|
||||||
|
("fr", "french"),
|
||||||
|
("ja", "japanese"),
|
||||||
|
("pt", "portuguese"),
|
||||||
|
("tr", "turkish"),
|
||||||
|
("pl", "polish"),
|
||||||
|
("ca", "catalan"),
|
||||||
|
("nl", "dutch"),
|
||||||
|
("ar", "arabic"),
|
||||||
|
("sv", "swedish"),
|
||||||
|
("it", "italian"),
|
||||||
|
("id", "indonesian"),
|
||||||
|
("hi", "hindi"),
|
||||||
|
("fi", "finnish"),
|
||||||
|
("vi", "vietnamese"),
|
||||||
|
("he", "hebrew"),
|
||||||
|
("uk", "ukrainian"),
|
||||||
|
("el", "greek"),
|
||||||
|
("ms", "malay"),
|
||||||
|
("cs", "czech"),
|
||||||
|
("ro", "romanian"),
|
||||||
|
("da", "danish"),
|
||||||
|
("hu", "hungarian"),
|
||||||
|
("ta", "tamil"),
|
||||||
|
("no", "norwegian"),
|
||||||
|
("th", "thai"),
|
||||||
|
("ur", "urdu"),
|
||||||
|
("hr", "croatian"),
|
||||||
|
("bg", "bulgarian"),
|
||||||
|
("lt", "lithuanian"),
|
||||||
|
("la", "latin"),
|
||||||
|
("mi", "maori"),
|
||||||
|
("ml", "malayalam"),
|
||||||
|
("cy", "welsh"),
|
||||||
|
("sk", "slovak"),
|
||||||
|
("te", "telugu"),
|
||||||
|
("fa", "persian"),
|
||||||
|
("lv", "latvian"),
|
||||||
|
("bn", "bengali"),
|
||||||
|
("sr", "serbian"),
|
||||||
|
("az", "azerbaijani"),
|
||||||
|
("sl", "slovenian"),
|
||||||
|
("kn", "kannada"),
|
||||||
|
("et", "estonian"),
|
||||||
|
("mk", "macedonian"),
|
||||||
|
("br", "breton"),
|
||||||
|
("eu", "basque"),
|
||||||
|
("is", "icelandic"),
|
||||||
|
("hy", "armenian"),
|
||||||
|
("ne", "nepali"),
|
||||||
|
("mn", "mongolian"),
|
||||||
|
("bs", "bosnian"),
|
||||||
|
("kk", "kazakh"),
|
||||||
|
("sq", "albanian"),
|
||||||
|
("sw", "swahili"),
|
||||||
|
("gl", "galician"),
|
||||||
|
("mr", "marathi"),
|
||||||
|
("pa", "punjabi"),
|
||||||
|
("si", "sinhala"),
|
||||||
|
("km", "khmer"),
|
||||||
|
("sn", "shona"),
|
||||||
|
("yo", "yoruba"),
|
||||||
|
("so", "somali"),
|
||||||
|
("af", "afrikaans"),
|
||||||
|
("oc", "occitan"),
|
||||||
|
("ka", "georgian"),
|
||||||
|
("be", "belarusian"),
|
||||||
|
("tg", "tajik"),
|
||||||
|
("sd", "sindhi"),
|
||||||
|
("gu", "gujarati"),
|
||||||
|
("am", "amharic"),
|
||||||
|
("yi", "yiddish"),
|
||||||
|
("lo", "lao"),
|
||||||
|
("uz", "uzbek"),
|
||||||
|
("fo", "faroese"),
|
||||||
|
("ht", "haitian creole"),
|
||||||
|
("ps", "pashto"),
|
||||||
|
("tk", "turkmen"),
|
||||||
|
("nn", "nynorsk"),
|
||||||
|
("mt", "maltese"),
|
||||||
|
("sa", "sanskrit"),
|
||||||
|
("lb", "luxembourgish"),
|
||||||
|
("my", "myanmar"),
|
||||||
|
("bo", "tibetan"),
|
||||||
|
("tl", "tagalog"),
|
||||||
|
("mg", "malagasy"),
|
||||||
|
("as", "assamese"),
|
||||||
|
("tt", "tatar"),
|
||||||
|
("haw", "hawaiian"),
|
||||||
|
("ln", "lingala"),
|
||||||
|
("ha", "hausa"),
|
||||||
|
("ba", "bashkir"),
|
||||||
|
("jw", "javanese"),
|
||||||
|
("su", "sundanese"),
|
||||||
|
];
|
||||||
|
|
||||||
|
pub fn detect_language(model: &Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<()> {
|
||||||
|
let device = mel.device();
|
||||||
|
let language_token_ids = LANGUAGES
|
||||||
|
.iter()
|
||||||
|
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
|
||||||
|
let audio_features = model.encoder.forward(mel)?;
|
||||||
|
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||||
|
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||||
|
println!("{tokens}");
|
||||||
|
println!("{audio_features}");
|
||||||
|
let logits = model
|
||||||
|
.decoder
|
||||||
|
.forward(&tokens, &audio_features)?
|
||||||
|
.i(0)?
|
||||||
|
.i(0)?;
|
||||||
|
println!("{logits}");
|
||||||
|
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||||
|
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||||
|
let probs = probs.to_vec1::<f32>()?;
|
||||||
|
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
|
||||||
|
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||||
|
for ((_, language), p) in probs.iter().take(5) {
|
||||||
|
println!("{language}: {p}")
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
Reference in New Issue
Block a user