From 27e70a50939b647a7c2e80428647f5668e592607 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radam=C3=A9s=20Ajna?= Date: Wed, 4 Oct 2023 12:22:57 -0700 Subject: [PATCH] Whisper quantized wasm (#1028) * [Whisper] Update to use quantized model * [whisper] add language detection * [whisper] change assets location * [whisper] adapt js example with quantized models * [whisper] better task parsing * [whisper] minor fixes --- .gitignore | 7 +- candle-wasm-examples/whisper/Cargo.toml | 1 + candle-wasm-examples/whisper/README.md | 30 +- candle-wasm-examples/whisper/index.html | 44 +- candle-wasm-examples/whisper/lib-example.html | 90 ++-- candle-wasm-examples/whisper/src/app.rs | 45 +- candle-wasm-examples/whisper/src/audio.rs | 8 +- candle-wasm-examples/whisper/src/bin/m.rs | 14 +- candle-wasm-examples/whisper/src/languages.rs | 101 +++++ candle-wasm-examples/whisper/src/lib.rs | 18 +- candle-wasm-examples/whisper/src/model.rs | 417 ------------------ candle-wasm-examples/whisper/src/worker.rs | 290 ++++++++---- candle-wasm-examples/whisper/whisperWorker.js | 71 ++- 13 files changed, 540 insertions(+), 596 deletions(-) create mode 100644 candle-wasm-examples/whisper/src/languages.rs delete mode 100644 candle-wasm-examples/whisper/src/model.rs diff --git a/.gitignore b/.gitignore index d0a8c320..9a112a61 100644 --- a/.gitignore +++ b/.gitignore @@ -29,9 +29,10 @@ trace-*.json candle-wasm-examples/*/build candle-wasm-examples/*/*.bin candle-wasm-examples/*/*.jpeg -candle-wasm-examples/*/*.wav -candle-wasm-examples/*/*.safetensors +candle-wasm-examples/*/audios/*.wav +candle-wasm-examples/**/*.safetensors +candle-wasm-examples/**/*.gguf candle-wasm-examples/*/package-lock.json - +candle-wasm-examples/**/config*.json .DS_Store .idea/* diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml index bf9a4b34..9e01cd26 100644 --- a/candle-wasm-examples/whisper/Cargo.toml +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -11,6 +11,7 @@ license.workspace = true [dependencies] candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" } candle-nn = { path = "../../candle-nn", version = "0.3.0" } +candle-transformers = { path = "../../candle-transformers", version = "0.3.0" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/whisper/README.md b/candle-wasm-examples/whisper/README.md index b847a965..85a52340 100644 --- a/candle-wasm-examples/whisper/README.md +++ b/candle-wasm-examples/whisper/README.md @@ -10,19 +10,31 @@ From the `candle-wasm-examples/whisper` directory run: Download assets: ```bash -# Model and tokenizer +# mel filters wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors -wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tiny.en.safetensors -wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tokenizer.en.json +# Model and tokenizer tiny.en +wget -c https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors -P whisper-tiny.en +wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json -P whisper-tiny.en +wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json -P whisper-tiny.en +# model and tokenizer tiny multilanguage +wget -c https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors -P whisper-tiny +wget -c https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json -P whisper-tiny +wget -c 
https://huggingface.co/openai/whisper-tiny/raw/main/config.json -P whisper-tiny + +#quantized +wget -c https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf -P quantized +wget -c https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json -P quantized +wget -c https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json -P quantized + # Audio samples -wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -O gb0.wav -wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -O a13.wav -wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -O gb1.wav -wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -O hp0.wav -wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -O jfk.wav -wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -O mm0.wav +wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -P audios +wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -P audios +wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -P audios +wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -P audios +wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -P audios +wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -P audios ``` diff --git a/candle-wasm-examples/whisper/index.html b/candle-wasm-examples/whisper/index.html index 7a21c4f2..98428205 100644 --- a/candle-wasm-examples/whisper/index.html +++ b/candle-wasm-examples/whisper/index.html @@ -3,22 +3,38 @@ Welcome to Candle! 
- - - - - - - - - - - + + + + + + + + - - - + + + + + + diff --git a/candle-wasm-examples/whisper/lib-example.html b/candle-wasm-examples/whisper/lib-example.html index 3cfd87a7..eb7e953f 100644 --- a/candle-wasm-examples/whisper/lib-example.html +++ b/candle-wasm-examples/whisper/lib-example.html @@ -26,9 +26,30 @@ // models base url const MODELS = { + tiny_multilingual: { + base_url: "https://huggingface.co/openai/whisper-tiny/resolve/main/", + model: "model.safetensors", + tokenizer: "tokenizer.json", + config: "config.json", + }, tiny_en: { base_url: - "https://huggingface.co/openai/whisper-tiny.en/resolve/refs%2Fpr%2F17/", + "https://huggingface.co/openai/whisper-tiny.en/resolve/main/", + model: "model.safetensors", + tokenizer: "tokenizer.json", + config: "config.json", + }, + tiny_quantized_multilingual_q80: { + base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/", + model: "model-tiny-q80.gguf", + tokenizer: "tokenizer-tiny.json", + config: "config-tiny.json", + }, + tiny_en_quantized_q80: { + base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/", + model: "model-tiny-q80.gguf", + tokenizer: "tokenizer-tiny-en.json", + config: "config-tiny-en.json", }, }; const whisperWorker = new Worker("./whisperWorker.js", { @@ -39,6 +60,7 @@ weightsURL, // URL to the weights file modelID, // model ID tokenizerURL, // URL to the tokenizer file + configURL, // model config URL mel_filtersURL, // URL to the mel filters file audioURL, // URL to the audio file updateStatus // function to update the status @@ -48,6 +70,7 @@ weightsURL, modelID, tokenizerURL, + configURL, mel_filtersURL, audioURL, }); @@ -128,13 +151,16 @@ return; } const modelID = document.querySelector("#model").value; - const modelURL = MODELS[modelID].base_url + "model.safetensors"; - const tokenizerURL = MODELS[modelID].base_url + "tokenizer.json"; + const model = MODELS[modelID]; + const modelURL = model.base_url + model.model; + const tokenizerURL = model.base_url + model.tokenizer; + const configURL = model.base_url + model.config; classifyAudio( modelURL, modelID, tokenizerURL, + configURL, "mel_filters.safetensors", audioURL, updateStatus @@ -178,8 +204,7 @@ + class="underline hover:text-blue-500 hover:no-underline"> OpenAI Whisper models and WASM runtime built with @@ -196,37 +221,38 @@
+ class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden">
+ class="flex flex-col items-center justify-center space-y-1 text-center"> + xmlns="http://www.w3.org/2000/svg"> + fill="#000" />
+ class="w-full p-2 select-none">
@@ -253,43 +277,37 @@

Examples:

@@ -300,16 +318,14 @@

Transcription:

+ class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"> No transcription results yet Result, JsValue> { @@ -54,14 +59,46 @@ pub struct App { } async fn model_data_load() -> Result { - let tokenizer = fetch_url("tokenizer.en.json").await?; - let mel_filters = fetch_url("mel_filters.safetensors").await?; - let weights = fetch_url("tiny.en.safetensors").await?; + let quantized = false; + let is_multilingual = false; + + let (tokenizer, mel_filters, weights, config) = if quantized { + console_log!("loading quantized weights"); + let tokenizer = fetch_url("quantized/tokenizer-tiny-en.json").await?; + let mel_filters = fetch_url("mel_filters.safetensors").await?; + let weights = fetch_url("quantized/model-tiny-en-q80.gguf").await?; + let config = fetch_url("quantized/config-tiny-en.json").await?; + (tokenizer, mel_filters, weights, config) + } else { + console_log!("loading float weights"); + if is_multilingual { + let mel_filters = fetch_url("mel_filters.safetensors").await?; + let tokenizer = fetch_url("whisper-tiny/tokenizer.json").await?; + let weights = fetch_url("whisper-tiny/model.safetensors").await?; + let config = fetch_url("whisper-tiny/config.json").await?; + (tokenizer, mel_filters, weights, config) + } else { + let mel_filters = fetch_url("mel_filters.safetensors").await?; + let tokenizer = fetch_url("whisper-tiny.en/tokenizer.json").await?; + let weights = fetch_url("whisper-tiny.en/model.safetensors").await?; + let config = fetch_url("whisper-tiny.en/config.json").await?; + (tokenizer, mel_filters, weights, config) + } + }; + + let timestamps = true; + let _task = Some("transcribe".to_string()); console_log!("{}", weights.len()); Ok(ModelData { tokenizer, mel_filters, weights, + config, + quantized, + timestamps, + task: None, + is_multilingual, + language: None, }) } diff --git a/candle-wasm-examples/whisper/src/audio.rs b/candle-wasm-examples/whisper/src/audio.rs index 8c0e682c..10974d15 100644 --- a/candle-wasm-examples/whisper/src/audio.rs +++ b/candle-wasm-examples/whisper/src/audio.rs @@ -168,7 +168,7 @@ fn log_mel_spectrogram_( let n_len = samples.len() / fft_step; // pad audio with at least one extra chunk of zeros - let pad = 100 * worker::CHUNK_LENGTH / 2; + let pad = 100 * worker::m::CHUNK_LENGTH / 2; let n_len = if n_len % pad != 0 { (n_len / pad + 1) * pad } else { @@ -206,9 +206,9 @@ pub fn pcm_to_mel( let mel = log_mel_spectrogram_( samples, filters, - worker::N_FFT, - worker::HOP_LENGTH, - worker::N_MELS, + worker::m::N_FFT, + worker::m::HOP_LENGTH, + worker::m::N_MELS, false, ); Ok(mel) diff --git a/candle-wasm-examples/whisper/src/bin/m.rs b/candle-wasm-examples/whisper/src/bin/m.rs index 0716a20d..67b7a189 100644 --- a/candle-wasm-examples/whisper/src/bin/m.rs +++ b/candle-wasm-examples/whisper/src/bin/m.rs @@ -9,15 +9,28 @@ pub struct Decoder { #[wasm_bindgen] impl Decoder { #[wasm_bindgen(constructor)] + #[allow(clippy::too_many_arguments)] pub fn new( weights: Vec, tokenizer: Vec, mel_filters: Vec, + config: Vec, + quantized: bool, + is_multilingual: bool, + timestamps: bool, + task: Option, + language: Option, ) -> Result { let decoder = D::load(ModelData { tokenizer, mel_filters, + config, + quantized, weights, + is_multilingual, + timestamps, + task, + language, }); match decoder { @@ -32,7 +45,6 @@ impl Decoder { .decoder .convert_and_run(&wav_input) .map_err(|e| JsError::new(&e.to_string()))?; - let json = serde_json::to_string(&segments)?; Ok(json) } diff --git a/candle-wasm-examples/whisper/src/languages.rs 
b/candle-wasm-examples/whisper/src/languages.rs new file mode 100644 index 00000000..fcbf9a7c --- /dev/null +++ b/candle-wasm-examples/whisper/src/languages.rs @@ -0,0 +1,101 @@ +pub const LANGUAGES: [(&str, &str); 99] = [ + ("en", "english"), + ("zh", "chinese"), + ("de", "german"), + ("es", "spanish"), + ("ru", "russian"), + ("ko", "korean"), + ("fr", "french"), + ("ja", "japanese"), + ("pt", "portuguese"), + ("tr", "turkish"), + ("pl", "polish"), + ("ca", "catalan"), + ("nl", "dutch"), + ("ar", "arabic"), + ("sv", "swedish"), + ("it", "italian"), + ("id", "indonesian"), + ("hi", "hindi"), + ("fi", "finnish"), + ("vi", "vietnamese"), + ("he", "hebrew"), + ("uk", "ukrainian"), + ("el", "greek"), + ("ms", "malay"), + ("cs", "czech"), + ("ro", "romanian"), + ("da", "danish"), + ("hu", "hungarian"), + ("ta", "tamil"), + ("no", "norwegian"), + ("th", "thai"), + ("ur", "urdu"), + ("hr", "croatian"), + ("bg", "bulgarian"), + ("lt", "lithuanian"), + ("la", "latin"), + ("mi", "maori"), + ("ml", "malayalam"), + ("cy", "welsh"), + ("sk", "slovak"), + ("te", "telugu"), + ("fa", "persian"), + ("lv", "latvian"), + ("bn", "bengali"), + ("sr", "serbian"), + ("az", "azerbaijani"), + ("sl", "slovenian"), + ("kn", "kannada"), + ("et", "estonian"), + ("mk", "macedonian"), + ("br", "breton"), + ("eu", "basque"), + ("is", "icelandic"), + ("hy", "armenian"), + ("ne", "nepali"), + ("mn", "mongolian"), + ("bs", "bosnian"), + ("kk", "kazakh"), + ("sq", "albanian"), + ("sw", "swahili"), + ("gl", "galician"), + ("mr", "marathi"), + ("pa", "punjabi"), + ("si", "sinhala"), + ("km", "khmer"), + ("sn", "shona"), + ("yo", "yoruba"), + ("so", "somali"), + ("af", "afrikaans"), + ("oc", "occitan"), + ("ka", "georgian"), + ("be", "belarusian"), + ("tg", "tajik"), + ("sd", "sindhi"), + ("gu", "gujarati"), + ("am", "amharic"), + ("yi", "yiddish"), + ("lo", "lao"), + ("uz", "uzbek"), + ("fo", "faroese"), + ("ht", "haitian creole"), + ("ps", "pashto"), + ("tk", "turkmen"), + ("nn", "nynorsk"), + ("mt", "maltese"), + ("sa", "sanskrit"), + ("lb", "luxembourgish"), + ("my", "myanmar"), + ("bo", "tibetan"), + ("tl", "tagalog"), + ("mg", "malagasy"), + ("as", "assamese"), + ("tt", "tatar"), + ("haw", "hawaiian"), + ("ln", "lingala"), + ("ha", "hausa"), + ("ba", "bashkir"), + ("jw", "javanese"), + ("su", "sundanese"), +]; diff --git a/candle-wasm-examples/whisper/src/lib.rs b/candle-wasm-examples/whisper/src/lib.rs index 141714f5..f1832012 100644 --- a/candle-wasm-examples/whisper/src/lib.rs +++ b/candle-wasm-examples/whisper/src/lib.rs @@ -4,14 +4,14 @@ struct Timer { label: &'static str, } -impl Timer { - fn new(label: &'static str) -> Self { - if WITH_TIMER { - web_sys::console::time_with_label(label); - } - Self { label } - } -} +// impl Timer { +// fn new(label: &'static str) -> Self { +// if WITH_TIMER { +// web_sys::console::time_with_label(label); +// } +// Self { label } +// } +// } impl Drop for Timer { fn drop(&mut self) { @@ -23,7 +23,7 @@ impl Drop for Timer { mod app; mod audio; -mod model; +pub mod languages; pub mod worker; pub use app::App; pub use worker::Worker; diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs deleted file mode 100644 index 8574124b..00000000 --- a/candle-wasm-examples/whisper/src/model.rs +++ /dev/null @@ -1,417 +0,0 @@ -// We use anyhow rather than candle errors as it provides better support for getting the backtrace -// back when using RUST_LIB_BACKTRACE=1. 
-use anyhow::Result; -use candle::{Device, Tensor}; -use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; -use serde::Deserialize; - -// The names in comments correspond to the original implementation: -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub struct Config { - pub num_mel_bins: usize, // n_mels - pub max_source_positions: usize, // n_audio_ctx - pub d_model: usize, // n_audio_state - pub encoder_attention_heads: usize, // n_audio_head - pub encoder_layers: usize, // n_audio_layer - pub vocab_size: usize, // n_vocab - pub max_target_positions: usize, // n_text_ctx - // pub n_text_state: usize, - pub decoder_attention_heads: usize, // n_text_head - pub decoder_layers: usize, // n_text_layer -} - -impl Config { - pub fn tiny_en() -> Self { - Self { - num_mel_bins: 80, - vocab_size: 51864, - max_source_positions: 1500, - d_model: 384, - encoder_attention_heads: 6, - encoder_layers: 4, - max_target_positions: 448, - // n_text_state: 384, - decoder_attention_heads: 6, - decoder_layers: 4, - } - } -} - -// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm -// specific monitoring. -#[derive(Debug)] -struct Linear { - weight: Tensor, - bias: Option, -} - -impl Linear { - fn new(weight: Tensor, bias: Option) -> Self { - Self { weight, bias } - } - - fn forward(&self, x: &Tensor) -> candle::Result { - let _timer = crate::Timer::new("Linear::forward"); - let w = match x.dims() { - &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, - _ => self.weight.t()?, - }; - let x = { - let _timer = crate::Timer::new("Linear::matmul"); - x.matmul(&w)? - }; - match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - } - } -} - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let weight = vb.get((size2, size1), "weight")?; - let bias = vb.get(size2, "bias")?; - Ok(Linear::new(weight, Some(bias))) -} - -fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let weight = vb.get((size2, size1), "weight")?; - Ok(Linear::new(weight, None)) -} - -fn conv1d( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - config: Conv1dConfig, - vb: VarBuilder, -) -> Result { - let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; - let bias = vb.get(out_channels, "bias")?; - Ok(Conv1d::new(weight, Some(bias), config)) -} - -fn layer_norm(size: usize, vb: VarBuilder) -> Result { - let weight = vb.get(size, "weight")?; - let bias = vb.get(size, "bias")?; - Ok(LayerNorm::new(weight, bias, 1e-5)) -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 -struct MultiHeadAttention { - query: Linear, - key: Linear, - value: Linear, - out: Linear, - n_head: usize, - kv_cache: Option<(Tensor, Tensor)>, -} - -impl MultiHeadAttention { - fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result { - let query = linear(n_state, n_state, vb.pp("q_proj"))?; - let value = linear(n_state, n_state, vb.pp("v_proj"))?; - let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; - let out = linear(n_state, n_state, vb.pp("out_proj"))?; - Ok(Self { - query, - key, - value, - out, - n_head, - kv_cache: None, - }) - 
} - - fn forward( - &mut self, - x: &Tensor, - xa: Option<&Tensor>, - mask: Option<&Tensor>, - flush_cache: bool, - ) -> Result { - let _timer = crate::Timer::new("MultiHeadAttention::forward"); - let q = self.query.forward(x)?; - let (k, v) = match xa { - None => { - let k = self.key.forward(x)?; - let v = self.value.forward(x)?; - (k, v) - } - Some(x) => { - if flush_cache { - self.kv_cache = None; - } - if let Some((k, v)) = &self.kv_cache { - (k.clone(), v.clone()) - } else { - let k = self.key.forward(x)?; - let v = self.value.forward(x)?; - self.kv_cache = Some((k.clone(), v.clone())); - (k, v) - } - } - }; - let wv = self.qkv_attention(&q, &k, &v, mask)?; - let out = self.out.forward(&wv)?; - Ok(out) - } - - fn reshape_head(&self, x: &Tensor) -> Result { - let (n_batch, n_ctx, n_state) = x.dims3()?; - let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; - Ok(x.reshape(target_dims)?.transpose(1, 2)?) - } - - fn qkv_attention( - &self, - q: &Tensor, - k: &Tensor, - v: &Tensor, - mask: Option<&Tensor>, - ) -> Result { - let (_, n_ctx, n_state) = q.dims3()?; - let scale = ((n_state / self.n_head) as f64).powf(-0.25); - let q = (self.reshape_head(q)? * scale)?; - let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; - let v = self.reshape_head(v)?.contiguous()?; - let mut qk = { - let _timer = crate::Timer::new("qk::matmul"); - q.matmul(&k)? - }; - if let Some(mask) = mask { - let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?; - qk = qk.broadcast_add(&mask)? - } - let w = { - let _timer = crate::Timer::new("qk::softmax"); - candle_nn::ops::softmax(&qk, candle::D::Minus1)? - }; - let wv = { - let _timer = crate::Timer::new("wv::matmul"); - w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)? - }; - Ok(wv) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 -struct ResidualAttentionBlock { - attn: MultiHeadAttention, - attn_ln: LayerNorm, - cross_attn: Option<(MultiHeadAttention, LayerNorm)>, - mlp_linear1: Linear, - mlp_linear2: Linear, - mlp_ln: LayerNorm, -} - -impl ResidualAttentionBlock { - fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result { - let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; - let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; - let cross_attn = if ca { - let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; - let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; - Some((cross_attn, cross_attn_ln)) - } else { - None - }; - let n_mlp = n_state * 4; - let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; - let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; - let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; - Ok(Self { - attn, - attn_ln, - cross_attn, - mlp_linear1, - mlp_linear2, - mlp_ln, - }) - } - - fn forward( - &mut self, - x: &Tensor, - xa: Option<&Tensor>, - mask: Option<&Tensor>, - flush_kv_cache: bool, - ) -> Result { - let _timer = crate::Timer::new("ResidualAttentionBlock::forward"); - let attn = self - .attn - .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; - let mut x = (x + attn)?; - if let Some((attn, ln)) = &mut self.cross_attn { - x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; - } - let mlp = self.mlp_linear2.forward( - &self - .mlp_linear1 - .forward(&self.mlp_ln.forward(&x)?)? - .gelu()?, - )?; - Ok((x + mlp)?) 
- } -} - -fn sinusoids(length: usize, channels: usize) -> Result { - let max_timescale = 10000f32; - let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; - let inv_timescales: Vec<_> = (0..channels / 2) - .map(|i| (i as f32 * (-log_timescale_increment)).exp()) - .collect(); - let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; - let arange = Tensor::arange(0, length as u32, &Device::Cpu)? - .to_dtype(candle::DType::F32)? - .unsqueeze(1)?; - let sh = (length, channels / 2); - let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; - let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; - Ok(sincos) -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 -pub struct AudioEncoder { - conv1: Conv1d, - conv2: Conv1d, - positional_embedding: Tensor, - blocks: Vec, - ln_post: LayerNorm, -} - -impl AudioEncoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result { - let n_state = cfg.d_model; - let n_head = cfg.encoder_attention_heads; - let n_ctx = cfg.max_source_positions; - let cfg1 = Conv1dConfig { - padding: 1, - stride: 1, - groups: 1, - dilation: 1, - }; - let cfg2 = Conv1dConfig { - padding: 1, - stride: 2, - groups: 1, - dilation: 1, - }; - let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; - let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; - let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; - let blocks = (0..cfg.encoder_layers) - .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) - }) - .collect::>>()?; - let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; - Ok(Self { - conv1, - conv2, - positional_embedding, - blocks, - ln_post, - }) - } - pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result { - let _timer = crate::Timer::new("AudioEncoder::forward"); - let x = { - let _timer = crate::Timer::new("conv1::forward"); - self.conv1.forward(x)?.gelu()? - }; - let x = { - let _timer = crate::Timer::new("conv2::forward"); - self.conv2.forward(&x)?.gelu()? - }; - let x = x.transpose(1, 2)?; - let (_bsize, seq_len, _hidden) = x.dims3()?; - let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; - let mut x = x.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter_mut() { - x = block.forward(&x, None, None, flush_kv_cache)? 
- } - let x = self.ln_post.forward(&x)?; - Ok(x) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 -pub struct TextDecoder { - token_embedding: Embedding, - positional_embedding: Tensor, - blocks: Vec, - ln: LayerNorm, - mask: Tensor, -} - -impl TextDecoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result { - let _timer = crate::Timer::new("TextDecoder::forward"); - let n_state = cfg.d_model; - let n_head = cfg.decoder_attention_heads; - let n_ctx = cfg.max_target_positions; - let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; - let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; - let blocks = (0..cfg.decoder_layers) - .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) - }) - .collect::>>()?; - let ln = layer_norm(n_state, vb.pp("layer_norm"))?; - let mask: Vec<_> = (0..n_ctx) - .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) - .collect(); - let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; - - Ok(Self { - token_embedding, - positional_embedding, - blocks, - ln, - mask, - }) - } - - pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result { - let x_dims = x.dims(); - let last = x_dims[x_dims.len() - 1]; - let token_embedding = self.token_embedding.forward(x)?; - let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; - let mut x = token_embedding.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter_mut() { - x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; - } - let x = self.ln.forward(&x)?; - let w = self - .token_embedding - .embeddings() - .broadcast_left(x_dims[0])?; - let logits = x.matmul(&w.t()?)?; - Ok(logits) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 -pub struct Whisper { - pub encoder: AudioEncoder, - pub decoder: TextDecoder, - pub config: Config, -} - -impl Whisper { - pub fn load(vb: &VarBuilder, config: Config) -> Result { - let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; - let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; - Ok(Self { - encoder, - decoder, - config, - }) - } -} diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs index 6ea0954c..85272fe7 100644 --- a/candle-wasm-examples/whisper/src/worker.rs +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -1,7 +1,8 @@ -use crate::model::{Config, Whisper}; +use crate::languages::LANGUAGES; use anyhow::Error as E; -use candle::{safetensors::Load, DType, Device, Tensor}; +use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D}; use candle_nn::{ops::softmax, VarBuilder}; +pub use candle_transformers::models::whisper::{self as m, Config}; use rand::{distributions::Distribution, rngs::StdRng, SeedableRng}; use serde::{Deserialize, Serialize}; use tokenizers::Tokenizer; @@ -25,38 +26,46 @@ macro_rules! console_log { pub const DTYPE: DType = DType::F32; -// Audio parameters. 
-pub const SAMPLE_RATE: usize = 16000; -pub const N_FFT: usize = 400; -pub const N_MELS: usize = 80; -pub const HOP_LENGTH: usize = 160; -pub const CHUNK_LENGTH: usize = 30; -pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk -pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input +pub enum Model { + Normal(m::model::Whisper), + Quantized(m::quantized_model::Whisper), +} -pub const NO_SPEECH_THRESHOLD: f64 = 0.6; -pub const LOGPROB_THRESHOLD: f64 = -1.0; -pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; -pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; +// Maybe we should use some traits rather than doing the dispatch for all these. +impl Model { + pub fn config(&self) -> &Config { + match self { + Self::Normal(m) => &m.config, + Self::Quantized(m) => &m.config, + } + } -// Tokenizer dependent bits. -const SOT_TOKEN: &str = "<|startoftranscript|>"; -const TRANSCRIBE_TOKEN: &str = "<|transcribe|>"; -const TRANSLATE_TOKEN: &str = "<|translate|>"; -const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>"; -const EOT_TOKEN: &str = "<|endoftext|>"; -const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; + pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result { + match self { + Self::Normal(m) => m.encoder.forward(x, flush), + Self::Quantized(m) => m.encoder.forward(x, flush), + } + } -// From the _get_suppress_tokens function + 50362 (no timestamp) -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605 -pub const SUPPRESS_TOKENS: [u32; 91] = [ - 1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, - 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, - 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, - 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, - 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, - 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362, -]; + pub fn decoder_forward( + &mut self, + x: &Tensor, + xa: &Tensor, + flush: bool, + ) -> candle::Result { + match self { + Self::Normal(m) => m.decoder.forward(x, xa, flush), + Self::Quantized(m) => m.decoder.forward(x, xa, flush), + } + } + + pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result { + match self { + Self::Normal(m) => m.decoder.final_linear(x), + Self::Quantized(m) => m.decoder.final_linear(x), + } + } +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DecodingResult { @@ -77,8 +86,13 @@ pub struct Segment { #[allow(unused)] pub struct Decoder { - model: Whisper, + model: Model, + rng: rand::rngs::StdRng, + task: Option, + language: Option, + is_multilingual: bool, mel_filters: Vec, + timestamps: bool, tokenizer: Tokenizer, suppress_tokens: Tensor, sot_token: u32, @@ -90,32 +104,43 @@ pub struct Decoder { } impl Decoder { + #[allow(clippy::too_many_arguments)] fn new( - model: Whisper, + model: Model, tokenizer: Tokenizer, mel_filters: Vec, device: &Device, + task: Option, + language: Option, + is_multilingual: bool, + timestamps: bool, ) -> anyhow::Result { - let suppress_tokens: Vec = (0..model.config.vocab_size as u32) + let suppress_tokens: Vec = (0..model.config().vocab_size as u32) .map(|i| { - if SUPPRESS_TOKENS.contains(&i) { + if model.config().suppress_tokens.contains(&i) { f32::NEG_INFINITY } else { 0f32 } }) 
.collect(); - let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?; + let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?; let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; - let sot_token = token_id(&tokenizer, SOT_TOKEN)?; - let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?; - let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?; - let eot_token = token_id(&tokenizer, EOT_TOKEN)?; - let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?; + let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?; + let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?; + let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?; + let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?; + let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?; + let seed = 299792458; Ok(Self { model, - mel_filters, + rng: StdRng::seed_from_u64(seed), tokenizer, + mel_filters, + task, + timestamps, + language, + is_multilingual, suppress_tokens, sot_token, transcribe_token, @@ -126,40 +151,73 @@ impl Decoder { }) } - fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result { + fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result { let model = &mut self.model; - let audio_features = model.encoder.forward(mel, true)?; - console_log!("audio features: {:?}", audio_features.dims()); - let sample_len = model.config.max_target_positions / 2; + let language_token = match (self.is_multilingual, &self.language) { + (true, None) => Some(detect_language(model, &self.tokenizer, mel)?), + (false, None) => None, + (true, Some(language)) => { + match token_id(&self.tokenizer, &format!("<|{:?}|>", self.language)) { + Ok(token_id) => Some(token_id), + Err(_) => anyhow::bail!("language {language} is not supported"), + } + } + (false, Some(_)) => { + anyhow::bail!("a language cannot be set for non-multilingual models") + } + }; + + let audio_features = model.encoder_forward(mel, true)?; + println!("audio features: {:?}", audio_features.dims()); + let sample_len = model.config().max_target_positions / 2; let mut sum_logprob = 0f64; let mut no_speech_prob = f64::NAN; - let mut tokens = vec![self.sot_token, self.transcribe_token]; + let mut tokens = vec![self.sot_token]; + if let Some(language_token) = language_token { + tokens.push(language_token); + } + match self.task { + None | Some(Task::Transcribe) => tokens.push(self.transcribe_token), + Some(Task::Translate) => tokens.push(self.translate_token), + } + if !self.timestamps { + tokens.push(self.no_timestamps_token); + } for i in 0..sample_len { let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?; // The model expects a batch dim but this inference loop does not handle // it so we add it at this point. let tokens_t = tokens_t.unsqueeze(0)?; - let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?; - let logits = logits.squeeze(0)?; + let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?; // Extract the no speech probability on the first iteration by looking at the first // token logits and the probability for the according token. if i == 0 { - no_speech_prob = softmax(&logits.get(0)?, 0)? - .get(self.no_speech_token as usize)? + let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?; + no_speech_prob = softmax(&logits, 0)? + .i(self.no_speech_token as usize)? .to_scalar::()? as f64; } - let (seq_len, _) = logits.dims2()?; - let logits = logits - .get(seq_len - 1)? 
- .broadcast_add(&self.suppress_tokens)?; + let (_, seq_len, _) = ys.dims3()?; + let logits = model + .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)? + .i(0)? + .i(0)?; + // TODO: Besides suppress tokens, we should apply the heuristics from + // ApplyTimestampRules, i.e.: + // - Timestamps come in pairs, except before EOT. + // - Timestamps should be non-decreasing. + // - If the sum of the probabilities of timestamps is higher than any other tokens, + // only consider timestamps when sampling. + // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439 + let logits = logits.broadcast_add(&self.suppress_tokens)?; let next_token = if t > 0f64 { let prs = softmax(&(&logits / t)?, 0)?; let logits_v: Vec = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(&logits_v)?; - distr.sample(rng) as u32 + distr.sample(&mut self.rng) as u32 } else { let logits_v: Vec = logits.to_vec1()?; logits_v @@ -171,9 +229,9 @@ impl Decoder { }; tokens.push(next_token); let prob = softmax(&logits, candle::D::Minus1)? - .get(next_token as usize)? + .i(next_token as usize)? .to_scalar::()? as f64; - if next_token == self.eot_token || tokens.len() > model.config.max_target_positions { + if next_token == self.eot_token || tokens.len() > model.config().max_target_positions { break; } sum_logprob += prob.ln(); @@ -191,22 +249,18 @@ impl Decoder { }) } - fn decode_with_fallback( - &mut self, - segment: &Tensor, - rng: &mut StdRng, - ) -> anyhow::Result { - for (i, &t) in TEMPERATURES.iter().enumerate() { - let dr: Result = self.decode(segment, t, rng); - if i == TEMPERATURES.len() - 1 { + fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result { + for (i, &t) in m::TEMPERATURES.iter().enumerate() { + let dr: Result = self.decode(segment, t); + if i == m::TEMPERATURES.len() - 1 { return dr; } // On errors, we try again with a different temperature. 
match dr { Ok(dr) => { - let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD - || dr.avg_logprob < LOGPROB_THRESHOLD; - if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD { + let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD + || dr.avg_logprob < m::LOGPROB_THRESHOLD; + if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD { return Ok(dr); } } @@ -219,18 +273,17 @@ impl Decoder { } fn run(&mut self, mel: &Tensor) -> anyhow::Result> { - let mut rng = StdRng::seed_from_u64(299792458); let (_, _, content_frames) = mel.dims3()?; let mut seek = 0; let mut segments = vec![]; while seek < content_frames { - let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; - let segment_size = usize::min(content_frames - seek, N_FRAMES); + let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64; + let segment_size = usize::min(content_frames - seek, m::N_FRAMES); let mel_segment = mel.narrow(2, seek, segment_size)?; - let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; - let dr = self.decode_with_fallback(&mel_segment, &mut rng)?; + let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64; + let dr = self.decode_with_fallback(&mel_segment)?; seek += segment_size; - if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { + if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD { console_log!("no speech detected, skipping {seek} {dr:?}"); continue; } @@ -247,17 +300,39 @@ impl Decoder { pub fn load(md: ModelData) -> anyhow::Result { let device = Device::Cpu; - let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?; + let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(E::msg)?; let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?; let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?; console_log!("loaded mel filters {:?}", mel_filters.shape()); let mel_filters = mel_filters.flatten_all()?.to_vec1::()?; - let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?; - let config = Config::tiny_en(); - let whisper = Whisper::load(&vb, config)?; + let config: Config = serde_json::from_slice(&md.config)?; + let model = if md.quantized { + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer( + &md.weights, + )?; + Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?) + } else { + let vb = VarBuilder::from_buffered_safetensors(md.weights, m::DTYPE, &device)?; + Model::Normal(m::model::Whisper::load(&vb, config)?) 
+ }; console_log!("done loading model"); - let decoder = Self::new(whisper, tokenizer, mel_filters, &device)?; + + let task = match md.task.as_deref() { + Some("translate") => Some(Task::Translate), + _ => Some(Task::Transcribe), + }; + + let decoder = Self::new( + model, + tokenizer, + mel_filters, + &device, + task, + md.language, + md.is_multilingual, + md.timestamps, + )?; Ok(decoder) } @@ -266,8 +341,8 @@ impl Decoder { let mut wav_input = std::io::Cursor::new(wav_input); let (header, data) = wav::read(&mut wav_input)?; console_log!("loaded wav data: {header:?}"); - if header.sampling_rate != SAMPLE_RATE as u32 { - anyhow::bail!("wav file must have a {SAMPLE_RATE} sampling rate"); + if header.sampling_rate != m::SAMPLE_RATE as u32 { + anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE); } let data = data.as_sixteen().expect("expected 16 bit wav file"); let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] @@ -277,27 +352,74 @@ impl Decoder { console_log!("pcm data loaded {}", pcm_data.len()); let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?; let mel_len = mel.len(); - let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?; + let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?; console_log!("loaded mel: {:?}", mel.dims()); let segments = self.run(&mel)?; Ok(segments) } } +/// Returns the token id for the selected language. +pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result { + console_log!("detecting language"); + let (_bsize, _, seq_len) = mel.dims3()?; + let mel = mel.narrow( + 2, + 0, + usize::min(seq_len, model.config().max_source_positions), + )?; + let device = mel.device(); + + let language_token_ids = LANGUAGES + .iter() + .map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>"))) + .map(|e| e.map_err(E::msg)) + .collect::, E>>()?; + + let sot_token = token_id(tokenizer, m::SOT_TOKEN)?; + let audio_features = model.encoder_forward(&mel, true)?; + let tokens = Tensor::new(&[[sot_token]], device)?; + let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; + let ys = model.decoder_forward(&tokens, &audio_features, true)?; + let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?; + let logits = logits.index_select(&language_token_ids, 0)?; + let probs = candle_nn::ops::softmax(&logits, D::Minus1)?; + let probs = probs.to_vec1::()?; + let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::>(); + probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for ((_, language), p) in probs.iter().take(5) { + println!("{language}: {p}") + } + let token = &format!("<|{}|>", probs[0].0 .0); + let language = token_id(tokenizer, token)?; + console_log!("detected language: {language} {token}"); + Ok(language) +} pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result { match tokenizer.token_to_id(token) { None => candle::bail!("no token-id for {token}"), Some(id) => Ok(id), } } +#[derive(Serialize, Deserialize, Clone, Copy, Debug)] +pub enum Task { + Transcribe, + Translate, +} // Communication to the worker happens through bincode, the model weights and configs are fetched // on the main thread and transfered via the following structure. 
#[derive(Serialize, Deserialize)] pub struct ModelData { + pub weights: Vec, pub tokenizer: Vec, pub mel_filters: Vec, - pub weights: Vec, + pub config: Vec, + pub quantized: bool, + pub timestamps: bool, + pub is_multilingual: bool, + pub language: Option, + pub task: Option, } pub struct Worker { diff --git a/candle-wasm-examples/whisper/whisperWorker.js b/candle-wasm-examples/whisper/whisperWorker.js index d2ad8e0b..bd44f62c 100644 --- a/candle-wasm-examples/whisper/whisperWorker.js +++ b/candle-wasm-examples/whisper/whisperWorker.js @@ -17,23 +17,46 @@ class Whisper { static instance = {}; // Retrieve the Whisper model. When called for the first time, // this will load the model and save it for future use. - static async getInstance(weightsURL, modelID, tokenizerURL, mel_filtersURL) { + static async getInstance(params) { + const { + weightsURL, + modelID, + tokenizerURL, + mel_filtersURL, + configURL, + quantized, + is_multilingual, + timestamps, + task, + language, + } = params; // load individual modelID only once if (!this.instance[modelID]) { await init(); self.postMessage({ status: "loading", message: "Loading Model" }); - const [weightsArrayU8, tokenizerArrayU8, mel_filtersArrayU8] = - await Promise.all([ - fetchArrayBuffer(weightsURL), - fetchArrayBuffer(tokenizerURL), - fetchArrayBuffer(mel_filtersURL), - ]); + const [ + weightsArrayU8, + tokenizerArrayU8, + mel_filtersArrayU8, + configArrayU8, + ] = await Promise.all([ + fetchArrayBuffer(weightsURL), + fetchArrayBuffer(tokenizerURL), + fetchArrayBuffer(mel_filtersURL), + fetchArrayBuffer(configURL), + ]); this.instance[modelID] = new Decoder( weightsArrayU8, tokenizerArrayU8, - mel_filtersArrayU8 + mel_filtersArrayU8, + configArrayU8, + quantized, + is_multilingual, + timestamps, + task, + language ); } else { self.postMessage({ status: "loading", message: "Model Already Loaded" }); @@ -43,17 +66,37 @@ class Whisper { } self.addEventListener("message", async (event) => { - const { weightsURL, modelID, tokenizerURL, mel_filtersURL, audioURL } = - event.data; + const { + weightsURL, + modelID, + tokenizerURL, + configURL, + mel_filtersURL, + audioURL, + } = event.data; try { self.postMessage({ status: "decoding", message: "Starting Decoder" }); - - const decoder = await Whisper.getInstance( + let quantized = false; + if (modelID.includes("quantized")) { + quantized = true; + } + let is_multilingual = false; + if (modelID.includes("multilingual")) { + is_multilingual = true; + } + let timestamps = true; + const decoder = await Whisper.getInstance({ weightsURL, modelID, tokenizerURL, - mel_filtersURL - ); + mel_filtersURL, + configURL, + quantized, + is_multilingual, + timestamps, + task: null, + language: null, + }); self.postMessage({ status: "decoding", message: "Loading Audio" }); const audioArrayU8 = await fetchArrayBuffer(audioURL);