Whisper quantized wasm (#1028)

* [Whisper] Update to use quantized model

* [whisper] add language detection

* [whisper] change assets location

* [whisper] adapt js example with quantized models

* [whisper] better task parsing

* [whisper] minor fixes
This commit is contained in:
Radamés Ajna
2023-10-04 12:22:57 -07:00
committed by GitHub
parent c18a856e76
commit 27e70a5093
13 changed files with 540 additions and 596 deletions

View File

@ -11,6 +11,7 @@ license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -10,19 +10,31 @@ From the `candle-wasm-examples/whisper` directory run:
Download assets:
```bash
# Model and tokenizer
# mel filters
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tiny.en.safetensors
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tokenizer.en.json
# Model and tokenizer tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors -P whisper-tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json -P whisper-tiny.en
wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json -P whisper-tiny.en
# model and tokenizer tiny multilanguage
wget -c https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors -P whisper-tiny
wget -c https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json -P whisper-tiny
wget -c https://huggingface.co/openai/whisper-tiny/raw/main/config.json -P whisper-tiny
#quantized
wget -c https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf -P quantized
wget -c https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json -P quantized
wget -c https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json -P quantized
# Audio samples
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -O gb0.wav
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -O a13.wav
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -O gb1.wav
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -O hp0.wav
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -O jfk.wav
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -O mm0.wav
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -P audios
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -P audios
```

View File

@ -3,22 +3,38 @@
<head>
<meta charset="utf-8" />
<title>Welcome to Candle!</title>
<link data-trunk rel="copy-file" href="jfk.wav" />
<link data-trunk rel="copy-file" href="mm0.wav" />
<link data-trunk rel="copy-file" href="a13.wav" />
<link data-trunk rel="copy-file" href="gb0.wav" />
<link data-trunk rel="copy-file" href="gb1.wav" />
<link data-trunk rel="copy-file" href="hp0.wav" />
<link data-trunk rel="copy-file" href="tokenizer.en.json" />
<link data-trunk rel="copy-file" href="mel_filters.safetensors" />
<link data-trunk rel="copy-file" href="tiny.en.safetensors" />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
<!-- samples -->
<link data-trunk rel="copy-dir" href="audios" />
<!-- tiny.en -->
<link data-trunk rel="copy-dir" href="whisper-tiny.en" />
<!-- tiny -->
<link data-trunk rel="copy-dir" href="whisper-tiny" />
<!-- quantized -->
<link data-trunk rel="copy-dir" href="quantized" />
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
<link
data-trunk
rel="rust"
href="Cargo.toml"
data-bin="app"
data-type="main" />
<link
data-trunk
rel="rust"
href="Cargo.toml"
data-bin="worker"
data-type="worker" />
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic" />
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css" />
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css" />
</head>
<body></body>
</html>

View File

@ -26,9 +26,30 @@
// models base url
const MODELS = {
tiny_multilingual: {
base_url: "https://huggingface.co/openai/whisper-tiny/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
},
tiny_en: {
base_url:
"https://huggingface.co/openai/whisper-tiny.en/resolve/refs%2Fpr%2F17/",
"https://huggingface.co/openai/whisper-tiny.en/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
},
tiny_quantized_multilingual_q80: {
base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
model: "model-tiny-q80.gguf",
tokenizer: "tokenizer-tiny.json",
config: "config-tiny.json",
},
tiny_en_quantized_q80: {
base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
model: "model-tiny-q80.gguf",
tokenizer: "tokenizer-tiny-en.json",
config: "config-tiny-en.json",
},
};
const whisperWorker = new Worker("./whisperWorker.js", {
@ -39,6 +60,7 @@
weightsURL, // URL to the weights file
modelID, // model ID
tokenizerURL, // URL to the tokenizer file
configURL, // model config URL
mel_filtersURL, // URL to the mel filters file
audioURL, // URL to the audio file
updateStatus // function to update the status
@ -48,6 +70,7 @@
weightsURL,
modelID,
tokenizerURL,
configURL,
mel_filtersURL,
audioURL,
});
@ -128,13 +151,16 @@
return;
}
const modelID = document.querySelector("#model").value;
const modelURL = MODELS[modelID].base_url + "model.safetensors";
const tokenizerURL = MODELS[modelID].base_url + "tokenizer.json";
const model = MODELS[modelID];
const modelURL = model.base_url + model.model;
const tokenizerURL = model.base_url + model.tokenizer;
const configURL = model.base_url + model.config;
classifyAudio(
modelURL,
modelID,
tokenizerURL,
configURL,
"mel_filters.safetensors",
audioURL,
updateStatus
@ -178,8 +204,7 @@
<a
href="https://huggingface.co/openai/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>
class="underline hover:text-blue-500 hover:no-underline">
OpenAI Whisper models
</a>
and WASM runtime built with
@ -196,37 +221,38 @@
<label for="model" class="font-medium">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light"
>
class="border-2 border-gray-500 rounded-md font-light">
<option value="tiny_multilingual" selected>tiny (151 MB)</option>
<option value="tiny_en" selected>tiny.en (151 MB)</option>
<option value="tiny_quantized_multilingual_q80">
tiny quantized q80 (41.5 MB)
</option>
<option value="tiny_en_quantized_q80">
tiny.en quantized q80 (41.8 MB)
</option>
</select>
</div>
<!-- drag and drop area -->
<div class="relative">
<div
id="drop-area"
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden"
>
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden">
<div
class="flex flex-col items-center justify-center space-y-1 text-center"
>
class="flex flex-col items-center justify-center space-y-1 text-center">
<svg
width="25"
height="25"
viewBox="0 0 25 25"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
xmlns="http://www.w3.org/2000/svg">
<path
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
fill="#000"
/>
fill="#000" />
</svg>
<div class="flex text-sm text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
>
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700">
<span>Drag and drop your audio here</span>
<span class="block text-xs">or</span>
<span class="block text-xs">Click to upload</span>
@ -237,15 +263,13 @@
name="file-upload"
type="file"
accept="audio/*"
class="sr-only"
/>
class="sr-only" />
</div>
<audio
id="audio"
hidden
controls
class="w-full p-2 select-none"
></audio>
class="w-full p-2 select-none"></audio>
</div>
</div>
<div>
@ -253,43 +277,37 @@
<h3 class="font-medium">Examples:</h3>
<button
data-value="samples_jfk.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
>
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>jfk.wav</span>
<span class="text-xs block"> (352 kB)</span>
</button>
<button
data-value="samples_a13.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
>
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>a13.wav</span>
<span class="text-xs block"> (960 kB)</span>
</button>
<button
data-value="samples_mm0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
>
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>mm0.wav</span>
<span class="text-xs block new"> (957 kB)</span>
</button>
<button
data-value="samples_gb0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
>
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>gb0.wav </span>
<span class="text-xs block">(4.08 MB)</span>
</button>
<button
data-value="samples_gb1.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
>
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>gb1.wav </span>
<span class="text-xs block">(6.36 MB)</span>
</button>
<button
data-value="samples_hp0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
>
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
<span>hp0.wav </span>
<span class="text-xs block">(8.75 MB)</span>
</button>
@ -300,16 +318,14 @@
<button
id="detect"
disabled
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
>
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
Transcribe Audio
</button>
</div>
<div>
<h3 class="font-medium">Transcription:</h3>
<div
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
>
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
<p hidden id="output-generation" class="grid-rows-2"></p>
<span id="output-status" class="m-auto font-light"
>No transcription results yet</span

View File

@ -7,7 +7,12 @@ use yew::{html, Component, Context, Html};
use yew_agent::{Bridge, Bridged};
const SAMPLE_NAMES: [&str; 6] = [
"jfk.wav", "a13.wav", "gb0.wav", "gb1.wav", "hp0.wav", "mm0.wav",
"audios/samples_jfk.wav",
"audios/samples_a13.wav",
"audios/samples_gb0.wav",
"audios/samples_gb1.wav",
"audios/samples_hp0.wav",
"audios/samples_mm0.wav",
];
async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
@ -54,14 +59,46 @@ pub struct App {
}
async fn model_data_load() -> Result<ModelData, JsValue> {
let tokenizer = fetch_url("tokenizer.en.json").await?;
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let weights = fetch_url("tiny.en.safetensors").await?;
let quantized = false;
let is_multilingual = false;
let (tokenizer, mel_filters, weights, config) = if quantized {
console_log!("loading quantized weights");
let tokenizer = fetch_url("quantized/tokenizer-tiny-en.json").await?;
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let weights = fetch_url("quantized/model-tiny-en-q80.gguf").await?;
let config = fetch_url("quantized/config-tiny-en.json").await?;
(tokenizer, mel_filters, weights, config)
} else {
console_log!("loading float weights");
if is_multilingual {
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let tokenizer = fetch_url("whisper-tiny/tokenizer.json").await?;
let weights = fetch_url("whisper-tiny/model.safetensors").await?;
let config = fetch_url("whisper-tiny/config.json").await?;
(tokenizer, mel_filters, weights, config)
} else {
let mel_filters = fetch_url("mel_filters.safetensors").await?;
let tokenizer = fetch_url("whisper-tiny.en/tokenizer.json").await?;
let weights = fetch_url("whisper-tiny.en/model.safetensors").await?;
let config = fetch_url("whisper-tiny.en/config.json").await?;
(tokenizer, mel_filters, weights, config)
}
};
let timestamps = true;
let _task = Some("transcribe".to_string());
console_log!("{}", weights.len());
Ok(ModelData {
tokenizer,
mel_filters,
weights,
config,
quantized,
timestamps,
task: None,
is_multilingual,
language: None,
})
}

View File

@ -168,7 +168,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
let n_len = samples.len() / fft_step;
// pad audio with at least one extra chunk of zeros
let pad = 100 * worker::CHUNK_LENGTH / 2;
let pad = 100 * worker::m::CHUNK_LENGTH / 2;
let n_len = if n_len % pad != 0 {
(n_len / pad + 1) * pad
} else {
@ -206,9 +206,9 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
let mel = log_mel_spectrogram_(
samples,
filters,
worker::N_FFT,
worker::HOP_LENGTH,
worker::N_MELS,
worker::m::N_FFT,
worker::m::HOP_LENGTH,
worker::m::N_MELS,
false,
);
Ok(mel)

View File

@ -9,15 +9,28 @@ pub struct Decoder {
#[wasm_bindgen]
impl Decoder {
#[wasm_bindgen(constructor)]
#[allow(clippy::too_many_arguments)]
pub fn new(
weights: Vec<u8>,
tokenizer: Vec<u8>,
mel_filters: Vec<u8>,
config: Vec<u8>,
quantized: bool,
is_multilingual: bool,
timestamps: bool,
task: Option<String>,
language: Option<String>,
) -> Result<Decoder, JsError> {
let decoder = D::load(ModelData {
tokenizer,
mel_filters,
config,
quantized,
weights,
is_multilingual,
timestamps,
task,
language,
});
match decoder {
@ -32,7 +45,6 @@ impl Decoder {
.decoder
.convert_and_run(&wav_input)
.map_err(|e| JsError::new(&e.to_string()))?;
let json = serde_json::to_string(&segments)?;
Ok(json)
}

View File

@ -0,0 +1,101 @@
pub const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
("de", "german"),
("es", "spanish"),
("ru", "russian"),
("ko", "korean"),
("fr", "french"),
("ja", "japanese"),
("pt", "portuguese"),
("tr", "turkish"),
("pl", "polish"),
("ca", "catalan"),
("nl", "dutch"),
("ar", "arabic"),
("sv", "swedish"),
("it", "italian"),
("id", "indonesian"),
("hi", "hindi"),
("fi", "finnish"),
("vi", "vietnamese"),
("he", "hebrew"),
("uk", "ukrainian"),
("el", "greek"),
("ms", "malay"),
("cs", "czech"),
("ro", "romanian"),
("da", "danish"),
("hu", "hungarian"),
("ta", "tamil"),
("no", "norwegian"),
("th", "thai"),
("ur", "urdu"),
("hr", "croatian"),
("bg", "bulgarian"),
("lt", "lithuanian"),
("la", "latin"),
("mi", "maori"),
("ml", "malayalam"),
("cy", "welsh"),
("sk", "slovak"),
("te", "telugu"),
("fa", "persian"),
("lv", "latvian"),
("bn", "bengali"),
("sr", "serbian"),
("az", "azerbaijani"),
("sl", "slovenian"),
("kn", "kannada"),
("et", "estonian"),
("mk", "macedonian"),
("br", "breton"),
("eu", "basque"),
("is", "icelandic"),
("hy", "armenian"),
("ne", "nepali"),
("mn", "mongolian"),
("bs", "bosnian"),
("kk", "kazakh"),
("sq", "albanian"),
("sw", "swahili"),
("gl", "galician"),
("mr", "marathi"),
("pa", "punjabi"),
("si", "sinhala"),
("km", "khmer"),
("sn", "shona"),
("yo", "yoruba"),
("so", "somali"),
("af", "afrikaans"),
("oc", "occitan"),
("ka", "georgian"),
("be", "belarusian"),
("tg", "tajik"),
("sd", "sindhi"),
("gu", "gujarati"),
("am", "amharic"),
("yi", "yiddish"),
("lo", "lao"),
("uz", "uzbek"),
("fo", "faroese"),
("ht", "haitian creole"),
("ps", "pashto"),
("tk", "turkmen"),
("nn", "nynorsk"),
("mt", "maltese"),
("sa", "sanskrit"),
("lb", "luxembourgish"),
("my", "myanmar"),
("bo", "tibetan"),
("tl", "tagalog"),
("mg", "malagasy"),
("as", "assamese"),
("tt", "tatar"),
("haw", "hawaiian"),
("ln", "lingala"),
("ha", "hausa"),
("ba", "bashkir"),
("jw", "javanese"),
("su", "sundanese"),
];

View File

@ -4,14 +4,14 @@ struct Timer {
label: &'static str,
}
impl Timer {
fn new(label: &'static str) -> Self {
if WITH_TIMER {
web_sys::console::time_with_label(label);
}
Self { label }
}
}
// impl Timer {
// fn new(label: &'static str) -> Self {
// if WITH_TIMER {
// web_sys::console::time_with_label(label);
// }
// Self { label }
// }
// }
impl Drop for Timer {
fn drop(&mut self) {
@ -23,7 +23,7 @@ impl Drop for Timer {
mod app;
mod audio;
mod model;
pub mod languages;
pub mod worker;
pub use app::App;
pub use worker::Worker;

View File

@ -1,417 +0,0 @@
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub num_mel_bins: usize, // n_mels
pub max_source_positions: usize, // n_audio_ctx
pub d_model: usize, // n_audio_state
pub encoder_attention_heads: usize, // n_audio_head
pub encoder_layers: usize, // n_audio_layer
pub vocab_size: usize, // n_vocab
pub max_target_positions: usize, // n_text_ctx
// pub n_text_state: usize,
pub decoder_attention_heads: usize, // n_text_head
pub decoder_layers: usize, // n_text_layer
}
impl Config {
pub fn tiny_en() -> Self {
Self {
num_mel_bins: 80,
vocab_size: 51864,
max_source_positions: 1500,
d_model: 384,
encoder_attention_heads: 6,
encoder_layers: 4,
max_target_positions: 448,
// n_text_state: 384,
decoder_attention_heads: 6,
decoder_layers: 4,
}
}
}
// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm
// specific monitoring.
#[derive(Debug)]
struct Linear {
weight: Tensor,
bias: Option<Tensor>,
}
impl Linear {
fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _timer = crate::Timer::new("Linear::forward");
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = {
let _timer = crate::Timer::new("Linear::matmul");
x.matmul(&w)?
};
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
fn conv1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
let bias = vb.get(out_channels, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
Ok(LayerNorm::new(weight, bias, 1e-5))
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
struct MultiHeadAttention {
query: Linear,
key: Linear,
value: Linear,
out: Linear,
n_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
let out = linear(n_state, n_state, vb.pp("out_proj"))?;
Ok(Self {
query,
key,
value,
out,
n_head,
kv_cache: None,
})
}
fn forward(
&mut self,
x: &Tensor,
xa: Option<&Tensor>,
mask: Option<&Tensor>,
flush_cache: bool,
) -> Result<Tensor> {
let _timer = crate::Timer::new("MultiHeadAttention::forward");
let q = self.query.forward(x)?;
let (k, v) = match xa {
None => {
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
(k, v)
}
Some(x) => {
if flush_cache {
self.kv_cache = None;
}
if let Some((k, v)) = &self.kv_cache {
(k.clone(), v.clone())
} else {
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
self.kv_cache = Some((k.clone(), v.clone()));
(k, v)
}
}
};
let wv = self.qkv_attention(&q, &k, &v, mask)?;
let out = self.out.forward(&wv)?;
Ok(out)
}
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
}
fn qkv_attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
mask: Option<&Tensor>,
) -> Result<Tensor> {
let (_, n_ctx, n_state) = q.dims3()?;
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
let q = (self.reshape_head(q)? * scale)?;
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
let v = self.reshape_head(v)?.contiguous()?;
let mut qk = {
let _timer = crate::Timer::new("qk::matmul");
q.matmul(&k)?
};
if let Some(mask) = mask {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
let w = {
let _timer = crate::Timer::new("qk::softmax");
candle_nn::ops::softmax(&qk, candle::D::Minus1)?
};
let wv = {
let _timer = crate::Timer::new("wv::matmul");
w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?
};
Ok(wv)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
struct ResidualAttentionBlock {
attn: MultiHeadAttention,
attn_ln: LayerNorm,
cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
mlp_linear1: Linear,
mlp_linear2: Linear,
mlp_ln: LayerNorm,
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
Some((cross_attn, cross_attn_ln))
} else {
None
};
let n_mlp = n_state * 4;
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
Ok(Self {
attn,
attn_ln,
cross_attn,
mlp_linear1,
mlp_linear2,
mlp_ln,
})
}
fn forward(
&mut self,
x: &Tensor,
xa: Option<&Tensor>,
mask: Option<&Tensor>,
flush_kv_cache: bool,
) -> Result<Tensor> {
let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
let attn = self
.attn
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &mut self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
}
let mlp = self.mlp_linear2.forward(
&self
.mlp_linear1
.forward(&self.mlp_ln.forward(&x)?)?
.gelu()?,
)?;
Ok((x + mlp)?)
}
}
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
let max_timescale = 10000f32;
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
let inv_timescales: Vec<_> = (0..channels / 2)
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
.collect();
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
.to_dtype(candle::DType::F32)?
.unsqueeze(1)?;
let sh = (length, channels / 2);
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
Ok(sincos)
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
pub struct AudioEncoder {
conv1: Conv1d,
conv2: Conv1d,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
}
impl AudioEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
let cfg1 = Conv1dConfig {
padding: 1,
stride: 1,
groups: 1,
dilation: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
dilation: 1,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
Ok(Self {
conv1,
conv2,
positional_embedding,
blocks,
ln_post,
})
}
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
let _timer = crate::Timer::new("AudioEncoder::forward");
let x = {
let _timer = crate::Timer::new("conv1::forward");
self.conv1.forward(x)?.gelu()?
};
let x = {
let _timer = crate::Timer::new("conv2::forward");
self.conv2.forward(&x)?.gelu()?
};
let x = x.transpose(1, 2)?;
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
let mut x = x.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter_mut() {
x = block.forward(&x, None, None, flush_kv_cache)?
}
let x = self.ln_post.forward(&x)?;
Ok(x)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
pub struct TextDecoder {
token_embedding: Embedding,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let _timer = crate::Timer::new("TextDecoder::forward");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
let blocks = (0..cfg.decoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
Ok(Self {
token_embedding,
positional_embedding,
blocks,
ln,
mask,
})
}
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter_mut() {
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
}
let x = self.ln.forward(&x)?;
let w = self
.token_embedding
.embeddings()
.broadcast_left(x_dims[0])?;
let logits = x.matmul(&w.t()?)?;
Ok(logits)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
pub struct Whisper {
pub encoder: AudioEncoder,
pub decoder: TextDecoder,
pub config: Config,
}
impl Whisper {
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
Ok(Self {
encoder,
decoder,
config,
})
}
}

View File

@ -1,7 +1,8 @@
use crate::model::{Config, Whisper};
use crate::languages::LANGUAGES;
use anyhow::Error as E;
use candle::{safetensors::Load, DType, Device, Tensor};
use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
use candle_nn::{ops::softmax, VarBuilder};
pub use candle_transformers::models::whisper::{self as m, Config};
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
@ -25,38 +26,46 @@ macro_rules! console_log {
pub const DTYPE: DType = DType::F32;
// Audio parameters.
pub const SAMPLE_RATE: usize = 16000;
pub const N_FFT: usize = 400;
pub const N_MELS: usize = 80;
pub const HOP_LENGTH: usize = 160;
pub const CHUNK_LENGTH: usize = 30;
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
pub enum Model {
Normal(m::model::Whisper),
Quantized(m::quantized_model::Whisper),
}
pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
pub const LOGPROB_THRESHOLD: f64 = -1.0;
pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Maybe we should use some traits rather than doing the dispatch for all these.
impl Model {
pub fn config(&self) -> &Config {
match self {
Self::Normal(m) => &m.config,
Self::Quantized(m) => &m.config,
}
}
// Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const TRANSLATE_TOKEN: &str = "<|translate|>";
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.encoder.forward(x, flush),
Self::Quantized(m) => m.encoder.forward(x, flush),
}
}
// From the _get_suppress_tokens function + 50362 (no timestamp)
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
pub const SUPPRESS_TOKENS: [u32; 91] = [
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
];
pub fn decoder_forward(
&mut self,
x: &Tensor,
xa: &Tensor,
flush: bool,
) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.forward(x, xa, flush),
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
}
}
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.final_linear(x),
Self::Quantized(m) => m.decoder.final_linear(x),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecodingResult {
@ -77,8 +86,13 @@ pub struct Segment {
#[allow(unused)]
pub struct Decoder {
model: Whisper,
model: Model,
rng: rand::rngs::StdRng,
task: Option<Task>,
language: Option<String>,
is_multilingual: bool,
mel_filters: Vec<f32>,
timestamps: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
@ -90,32 +104,43 @@ pub struct Decoder {
}
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
model: Whisper,
model: Model,
tokenizer: Tokenizer,
mel_filters: Vec<f32>,
device: &Device,
task: Option<Task>,
language: Option<String>,
is_multilingual: bool,
timestamps: bool,
) -> anyhow::Result<Self> {
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
.map(|i| {
if SUPPRESS_TOKENS.contains(&i) {
if model.config().suppress_tokens.contains(&i) {
f32::NEG_INFINITY
} else {
0f32
}
})
.collect();
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
let seed = 299792458;
Ok(Self {
model,
mel_filters,
rng: StdRng::seed_from_u64(seed),
tokenizer,
mel_filters,
task,
timestamps,
language,
is_multilingual,
suppress_tokens,
sot_token,
transcribe_token,
@ -126,40 +151,73 @@ impl Decoder {
})
}
fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder.forward(mel, true)?;
console_log!("audio features: {:?}", audio_features.dims());
let sample_len = model.config.max_target_positions / 2;
let language_token = match (self.is_multilingual, &self.language) {
(true, None) => Some(detect_language(model, &self.tokenizer, mel)?),
(false, None) => None,
(true, Some(language)) => {
match token_id(&self.tokenizer, &format!("<|{:?}|>", self.language)) {
Ok(token_id) => Some(token_id),
Err(_) => anyhow::bail!("language {language} is not supported"),
}
}
(false, Some(_)) => {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
let audio_features = model.encoder_forward(mel, true)?;
println!("audio features: {:?}", audio_features.dims());
let sample_len = model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token, self.transcribe_token];
let mut tokens = vec![self.sot_token];
if let Some(language_token) = language_token {
tokens.push(language_token);
}
match self.task {
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);
}
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
let logits = logits.squeeze(0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(self.no_speech_token as usize)?
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}
let (seq_len, _) = logits.dims2()?;
let logits = logits
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
// ApplyTimestampRules, i.e.:
// - Timestamps come in pairs, except before EOT.
// - Timestamps should be non-decreasing.
// - If the sum of the probabilities of timestamps is higher than any other tokens,
// only consider timestamps when sampling.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(rng) as u32
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
@ -171,9 +229,9 @@ impl Decoder {
};
tokens.push(next_token);
let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
break;
}
sum_logprob += prob.ln();
@ -191,22 +249,18 @@ impl Decoder {
})
}
fn decode_with_fallback(
&mut self,
segment: &Tensor,
rng: &mut StdRng,
) -> anyhow::Result<DecodingResult> {
for (i, &t) in TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
if i == TEMPERATURES.len() - 1 {
fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult, _> = self.decode(segment, t);
if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
@ -219,18 +273,17 @@ impl Decoder {
}
fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
let mut rng = StdRng::seed_from_u64(299792458);
let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
let mut segments = vec![];
while seek < content_frames {
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES);
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
console_log!("no speech detected, skipping {seek} {dr:?}");
continue;
}
@ -247,17 +300,39 @@ impl Decoder {
pub fn load(md: ModelData) -> anyhow::Result<Self> {
let device = Device::Cpu;
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(E::msg)?;
let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
console_log!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?;
let config = Config::tiny_en();
let whisper = Whisper::load(&vb, config)?;
let config: Config = serde_json::from_slice(&md.config)?;
let model = if md.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
&md.weights,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb = VarBuilder::from_buffered_safetensors(md.weights, m::DTYPE, &device)?;
Model::Normal(m::model::Whisper::load(&vb, config)?)
};
console_log!("done loading model");
let decoder = Self::new(whisper, tokenizer, mel_filters, &device)?;
let task = match md.task.as_deref() {
Some("translate") => Some(Task::Translate),
_ => Some(Task::Transcribe),
};
let decoder = Self::new(
model,
tokenizer,
mel_filters,
&device,
task,
md.language,
md.is_multilingual,
md.timestamps,
)?;
Ok(decoder)
}
@ -266,8 +341,8 @@ impl Decoder {
let mut wav_input = std::io::Cursor::new(wav_input);
let (header, data) = wav::read(&mut wav_input)?;
console_log!("loaded wav data: {header:?}");
if header.sampling_rate != SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {SAMPLE_RATE} sampling rate");
if header.sampling_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@ -277,27 +352,74 @@ impl Decoder {
console_log!("pcm data loaded {}", pcm_data.len());
let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?;
let mel_len = mel.len();
let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
console_log!("loaded mel: {:?}", mel.dims());
let segments = self.run(&mel)?;
Ok(segments)
}
}
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32, E> {
console_log!("detecting language");
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
.map(|e| e.map_err(E::msg))
.collect::<Result<Vec<_>, E>>()?;
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}")
}
let token = &format!("<|{}|>", probs[0].0 .0);
let language = token_id(tokenizer, token)?;
console_log!("detected language: {language} {token}");
Ok(language)
}
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
match tokenizer.token_to_id(token) {
None => candle::bail!("no token-id for {token}"),
Some(id) => Ok(id),
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum Task {
Transcribe,
Translate,
}
// Communication to the worker happens through bincode, the model weights and configs are fetched
// on the main thread and transfered via the following structure.
#[derive(Serialize, Deserialize)]
pub struct ModelData {
pub weights: Vec<u8>,
pub tokenizer: Vec<u8>,
pub mel_filters: Vec<u8>,
pub weights: Vec<u8>,
pub config: Vec<u8>,
pub quantized: bool,
pub timestamps: bool,
pub is_multilingual: bool,
pub language: Option<String>,
pub task: Option<String>,
}
pub struct Worker {

View File

@ -17,23 +17,46 @@ class Whisper {
static instance = {};
// Retrieve the Whisper model. When called for the first time,
// this will load the model and save it for future use.
static async getInstance(weightsURL, modelID, tokenizerURL, mel_filtersURL) {
static async getInstance(params) {
const {
weightsURL,
modelID,
tokenizerURL,
mel_filtersURL,
configURL,
quantized,
is_multilingual,
timestamps,
task,
language,
} = params;
// load individual modelID only once
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8, mel_filtersArrayU8] =
await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(mel_filtersURL),
]);
const [
weightsArrayU8,
tokenizerArrayU8,
mel_filtersArrayU8,
configArrayU8,
] = await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(mel_filtersURL),
fetchArrayBuffer(configURL),
]);
this.instance[modelID] = new Decoder(
weightsArrayU8,
tokenizerArrayU8,
mel_filtersArrayU8
mel_filtersArrayU8,
configArrayU8,
quantized,
is_multilingual,
timestamps,
task,
language
);
} else {
self.postMessage({ status: "loading", message: "Model Already Loaded" });
@ -43,17 +66,37 @@ class Whisper {
}
self.addEventListener("message", async (event) => {
const { weightsURL, modelID, tokenizerURL, mel_filtersURL, audioURL } =
event.data;
const {
weightsURL,
modelID,
tokenizerURL,
configURL,
mel_filtersURL,
audioURL,
} = event.data;
try {
self.postMessage({ status: "decoding", message: "Starting Decoder" });
const decoder = await Whisper.getInstance(
let quantized = false;
if (modelID.includes("quantized")) {
quantized = true;
}
let is_multilingual = false;
if (modelID.includes("multilingual")) {
is_multilingual = true;
}
let timestamps = true;
const decoder = await Whisper.getInstance({
weightsURL,
modelID,
tokenizerURL,
mel_filtersURL
);
mel_filtersURL,
configURL,
quantized,
is_multilingual,
timestamps,
task: null,
language: null,
});
self.postMessage({ status: "decoding", message: "Loading Audio" });
const audioArrayU8 = await fetchArrayBuffer(audioURL);