Whisper quantized wasm (#1028)

* [Whisper] Update to use quantized model

* [whisper] add language detection

* [whisper] change assets location

* [whisper] adapt js example with quantized models

* [whisper] better task parsing

* [whisper] minor fixes
Radamés Ajna
2023-10-04 12:22:57 -07:00
committed by GitHub
parent c18a856e76
commit 27e70a5093
13 changed files with 540 additions and 596 deletions

.gitignore vendored
View File

@ -29,9 +29,10 @@ trace-*.json
candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
-candle-wasm-examples/*/*.wav
-candle-wasm-examples/*/*.safetensors
+candle-wasm-examples/*/audios/*.wav
+candle-wasm-examples/**/*.safetensors
+candle-wasm-examples/**/*.gguf
candle-wasm-examples/*/package-lock.json
+candle-wasm-examples/**/config*.json
.DS_Store
.idea/*

View File

@ -11,6 +11,7 @@ license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
+candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -10,19 +10,31 @@ From the `candle-wasm-examples/whisper` directory run:
Download assets:
```bash
-# Model and tokenizer
+# mel filters
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors
-wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tiny.en.safetensors
-wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tokenizer.en.json
+# Model and tokenizer tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors -P whisper-tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json -P whisper-tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json -P whisper-tiny.en
+# model and tokenizer tiny multilanguage
+wget -c https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors -P whisper-tiny
+wget -c https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json -P whisper-tiny
+wget -c https://huggingface.co/openai/whisper-tiny/raw/main/config.json -P whisper-tiny
+#quantized
+wget -c https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf -P quantized
+wget -c https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json -P quantized
+wget -c https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json -P quantized
# Audio samples
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -O gb0.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -O a13.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -O gb1.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -O hp0.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -O jfk.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -O mm0.wav
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -P audios
```
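For reference, a minimal sketch of how the example later consumes the downloaded `mel_filters.safetensors` (tensor name `mel_80`), assuming the `candle-core`, `safetensors` and `anyhow` crates are available:

```rust
// Sketch only: load the mel filters the same way the worker does.
use candle::{safetensors::Load, Device};

fn load_mel_filters() -> anyhow::Result<Vec<f32>> {
    let bytes = std::fs::read("mel_filters.safetensors")?;
    let tensors = safetensors::tensor::SafeTensors::deserialize(&bytes)?;
    // The file contains a single tensor named "mel_80".
    let mel = tensors.tensor("mel_80")?.load(&Device::Cpu)?;
    Ok(mel.flatten_all()?.to_vec1::<f32>()?)
}

fn main() -> anyhow::Result<()> {
    let mel_filters = load_mel_filters()?;
    println!("loaded {} mel filter coefficients", mel_filters.len());
    Ok(())
}
```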

View File

@ -3,22 +3,38 @@
  <head>
    <meta charset="utf-8" />
    <title>Welcome to Candle!</title>
-    <link data-trunk rel="copy-file" href="jfk.wav" />
-    <link data-trunk rel="copy-file" href="mm0.wav" />
-    <link data-trunk rel="copy-file" href="a13.wav" />
-    <link data-trunk rel="copy-file" href="gb0.wav" />
-    <link data-trunk rel="copy-file" href="gb1.wav" />
-    <link data-trunk rel="copy-file" href="hp0.wav" />
-    <link data-trunk rel="copy-file" href="tokenizer.en.json" />
    <link data-trunk rel="copy-file" href="mel_filters.safetensors" />
-    <link data-trunk rel="copy-file" href="tiny.en.safetensors" />
-    <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
-    <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
-    <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
-    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
-    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
+    <!-- samples -->
+    <link data-trunk rel="copy-dir" href="audios" />
+    <!-- tiny.en -->
+    <link data-trunk rel="copy-dir" href="whisper-tiny.en" />
+    <!-- tiny -->
+    <link data-trunk rel="copy-dir" href="whisper-tiny" />
+    <!-- quantized -->
+    <link data-trunk rel="copy-dir" href="quantized" />
+    <link
+      data-trunk
+      rel="rust"
+      href="Cargo.toml"
+      data-bin="app"
+      data-type="main" />
+    <link
+      data-trunk
+      rel="rust"
+      href="Cargo.toml"
+      data-bin="worker"
+      data-type="worker" />
+    <link
+      rel="stylesheet"
+      href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic" />
+    <link
+      rel="stylesheet"
+      href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css" />
+    <link
+      rel="stylesheet"
+      href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css" />
  </head>
  <body></body>
</html>

View File

@ -26,9 +26,30 @@
// models base url
const MODELS = {
+  tiny_multilingual: {
+    base_url: "https://huggingface.co/openai/whisper-tiny/resolve/main/",
+    model: "model.safetensors",
+    tokenizer: "tokenizer.json",
+    config: "config.json",
+  },
  tiny_en: {
    base_url:
-      "https://huggingface.co/openai/whisper-tiny.en/resolve/refs%2Fpr%2F17/",
+      "https://huggingface.co/openai/whisper-tiny.en/resolve/main/",
+    model: "model.safetensors",
+    tokenizer: "tokenizer.json",
+    config: "config.json",
+  },
+  tiny_quantized_multilingual_q80: {
+    base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
+    model: "model-tiny-q80.gguf",
+    tokenizer: "tokenizer-tiny.json",
+    config: "config-tiny.json",
+  },
+  tiny_en_quantized_q80: {
+    base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
+    model: "model-tiny-q80.gguf",
+    tokenizer: "tokenizer-tiny-en.json",
+    config: "config-tiny-en.json",
  },
};
const whisperWorker = new Worker("./whisperWorker.js", {
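To make the lookup concrete, each entry resolves to plain `base_url + filename` URLs; a small illustrative mirror of that concatenation (values taken from the `tiny_quantized_multilingual_q80` entry above):

```rust
fn main() {
    // Illustration only: the same string concatenation done in the JS above.
    let base_url = "https://huggingface.co/lmz/candle-whisper/resolve/main/";
    let model_url = format!("{base_url}{}", "model-tiny-q80.gguf");
    let tokenizer_url = format!("{base_url}{}", "tokenizer-tiny.json");
    let config_url = format!("{base_url}{}", "config-tiny.json");
    println!("{model_url}\n{tokenizer_url}\n{config_url}");
}
```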
@ -39,6 +60,7 @@
  weightsURL, // URL to the weights file
  modelID, // model ID
  tokenizerURL, // URL to the tokenizer file
+  configURL, // model config URL
  mel_filtersURL, // URL to the mel filters file
  audioURL, // URL to the audio file
  updateStatus // function to update the status
@ -48,6 +70,7 @@
    weightsURL,
    modelID,
    tokenizerURL,
+    configURL,
    mel_filtersURL,
    audioURL,
  });
@ -128,13 +151,16 @@
      return;
    }
    const modelID = document.querySelector("#model").value;
-    const modelURL = MODELS[modelID].base_url + "model.safetensors";
-    const tokenizerURL = MODELS[modelID].base_url + "tokenizer.json";
+    const model = MODELS[modelID];
+    const modelURL = model.base_url + model.model;
+    const tokenizerURL = model.base_url + model.tokenizer;
+    const configURL = model.base_url + model.config;
    classifyAudio(
      modelURL,
      modelID,
      tokenizerURL,
+      configURL,
      "mel_filters.safetensors",
      audioURL,
      updateStatus
@ -178,8 +204,7 @@
        <a
          href="https://huggingface.co/openai/"
          target="_blank"
-          class="underline hover:text-blue-500 hover:no-underline"
-        >
+          class="underline hover:text-blue-500 hover:no-underline">
          OpenAI Whisper models
        </a>
        and WASM runtime built with
@ -196,37 +221,38 @@
        <label for="model" class="font-medium">Models Options: </label>
        <select
          id="model"
-          class="border-2 border-gray-500 rounded-md font-light"
-        >
+          class="border-2 border-gray-500 rounded-md font-light">
+          <option value="tiny_multilingual" selected>tiny (151 MB)</option>
          <option value="tiny_en" selected>tiny.en (151 MB)</option>
+          <option value="tiny_quantized_multilingual_q80">
+            tiny quantized q80 (41.5 MB)
+          </option>
+          <option value="tiny_en_quantized_q80">
+            tiny.en quantized q80 (41.8 MB)
+          </option>
        </select>
      </div>
<!-- drag and drop area --> <!-- drag and drop area -->
<div class="relative"> <div class="relative">
<div <div
id="drop-area" id="drop-area"
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden" class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden">
>
<div <div
class="flex flex-col items-center justify-center space-y-1 text-center" class="flex flex-col items-center justify-center space-y-1 text-center">
>
<svg <svg
width="25" width="25"
height="25" height="25"
viewBox="0 0 25 25" viewBox="0 0 25 25"
fill="none" fill="none"
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg">
>
<path <path
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z" d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
fill="#000" fill="#000" />
/>
</svg> </svg>
<div class="flex text-sm text-gray-600"> <div class="flex text-sm text-gray-600">
<label <label
for="file-upload" for="file-upload"
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700" class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700">
>
<span>Drag and drop your audio here</span> <span>Drag and drop your audio here</span>
<span class="block text-xs">or</span> <span class="block text-xs">or</span>
<span class="block text-xs">Click to upload</span> <span class="block text-xs">Click to upload</span>
@ -237,15 +263,13 @@
name="file-upload" name="file-upload"
type="file" type="file"
accept="audio/*" accept="audio/*"
class="sr-only" class="sr-only" />
/>
</div> </div>
<audio <audio
id="audio" id="audio"
hidden hidden
controls controls
class="w-full p-2 select-none" class="w-full p-2 select-none"></audio>
></audio>
</div> </div>
</div> </div>
<div> <div>
@ -253,43 +277,37 @@
<h3 class="font-medium">Examples:</h3> <h3 class="font-medium">Examples:</h3>
<button <button
data-value="samples_jfk.wav" data-value="samples_jfk.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline" class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
>
<span>jfk.wav</span> <span>jfk.wav</span>
<span class="text-xs block"> (352 kB)</span> <span class="text-xs block"> (352 kB)</span>
</button> </button>
<button <button
data-value="samples_a13.wav" data-value="samples_a13.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline" class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
>
<span>a13.wav</span> <span>a13.wav</span>
<span class="text-xs block"> (960 kB)</span> <span class="text-xs block"> (960 kB)</span>
</button> </button>
<button <button
data-value="samples_mm0.wav" data-value="samples_mm0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline" class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
>
<span>mm0.wav</span> <span>mm0.wav</span>
<span class="text-xs block new"> (957 kB)</span> <span class="text-xs block new"> (957 kB)</span>
</button> </button>
<button <button
data-value="samples_gb0.wav" data-value="samples_gb0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline" class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
>
<span>gb0.wav </span> <span>gb0.wav </span>
<span class="text-xs block">(4.08 MB)</span> <span class="text-xs block">(4.08 MB)</span>
</button> </button>
<button <button
data-value="samples_gb1.wav" data-value="samples_gb1.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline" class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
>
<span>gb1.wav </span> <span>gb1.wav </span>
<span class="text-xs block">(6.36 MB)</span> <span class="text-xs block">(6.36 MB)</span>
</button> </button>
<button <button
data-value="samples_hp0.wav" data-value="samples_hp0.wav"
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline" class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
>
<span>hp0.wav </span> <span>hp0.wav </span>
<span class="text-xs block">(8.75 MB)</span> <span class="text-xs block">(8.75 MB)</span>
</button> </button>
@ -300,16 +318,14 @@
<button <button
id="detect" id="detect"
disabled disabled
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed" class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
>
Transcribe Audio Transcribe Audio
</button> </button>
</div> </div>
<div> <div>
<h3 class="font-medium">Transcription:</h3> <h3 class="font-medium">Transcription:</h3>
<div <div
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2" class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
>
<p hidden id="output-generation" class="grid-rows-2"></p> <p hidden id="output-generation" class="grid-rows-2"></p>
<span id="output-status" class="m-auto font-light" <span id="output-status" class="m-auto font-light"
>No transcription results yet</span >No transcription results yet</span

View File

@ -7,7 +7,12 @@ use yew::{html, Component, Context, Html};
use yew_agent::{Bridge, Bridged};
const SAMPLE_NAMES: [&str; 6] = [
-    "jfk.wav", "a13.wav", "gb0.wav", "gb1.wav", "hp0.wav", "mm0.wav",
+    "audios/samples_jfk.wav",
+    "audios/samples_a13.wav",
+    "audios/samples_gb0.wav",
+    "audios/samples_gb1.wav",
+    "audios/samples_hp0.wav",
+    "audios/samples_mm0.wav",
];
async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
@ -54,14 +59,46 @@ pub struct App {
}
async fn model_data_load() -> Result<ModelData, JsValue> {
-    let tokenizer = fetch_url("tokenizer.en.json").await?;
-    let mel_filters = fetch_url("mel_filters.safetensors").await?;
-    let weights = fetch_url("tiny.en.safetensors").await?;
+    let quantized = false;
+    let is_multilingual = false;
+
+    let (tokenizer, mel_filters, weights, config) = if quantized {
+        console_log!("loading quantized weights");
+        let tokenizer = fetch_url("quantized/tokenizer-tiny-en.json").await?;
+        let mel_filters = fetch_url("mel_filters.safetensors").await?;
+        let weights = fetch_url("quantized/model-tiny-en-q80.gguf").await?;
+        let config = fetch_url("quantized/config-tiny-en.json").await?;
+        (tokenizer, mel_filters, weights, config)
+    } else {
+        console_log!("loading float weights");
+        if is_multilingual {
+            let mel_filters = fetch_url("mel_filters.safetensors").await?;
+            let tokenizer = fetch_url("whisper-tiny/tokenizer.json").await?;
+            let weights = fetch_url("whisper-tiny/model.safetensors").await?;
+            let config = fetch_url("whisper-tiny/config.json").await?;
+            (tokenizer, mel_filters, weights, config)
+        } else {
+            let mel_filters = fetch_url("mel_filters.safetensors").await?;
+            let tokenizer = fetch_url("whisper-tiny.en/tokenizer.json").await?;
+            let weights = fetch_url("whisper-tiny.en/model.safetensors").await?;
+            let config = fetch_url("whisper-tiny.en/config.json").await?;
+            (tokenizer, mel_filters, weights, config)
+        }
+    };
+
+    let timestamps = true;
+    let _task = Some("transcribe".to_string());
    console_log!("{}", weights.len());
    Ok(ModelData {
        tokenizer,
        mel_filters,
        weights,
+        config,
+        quantized,
+        timestamps,
+        task: None,
+        is_multilingual,
+        language: None,
    })
}

View File

@ -168,7 +168,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
    let n_len = samples.len() / fft_step;
    // pad audio with at least one extra chunk of zeros
-    let pad = 100 * worker::CHUNK_LENGTH / 2;
+    let pad = 100 * worker::m::CHUNK_LENGTH / 2;
    let n_len = if n_len % pad != 0 {
        (n_len / pad + 1) * pad
    } else {
@ -206,9 +206,9 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
    let mel = log_mel_spectrogram_(
        samples,
        filters,
-        worker::N_FFT,
-        worker::HOP_LENGTH,
-        worker::N_MELS,
+        worker::m::N_FFT,
+        worker::m::HOP_LENGTH,
+        worker::m::N_MELS,
        false,
    );
    Ok(mel)
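For a concrete sense of the padding above, a minimal sketch assuming the whisper constants that now come from `candle_transformers::models::whisper` (CHUNK_LENGTH = 30, HOP_LENGTH = 160, so `pad` is 1500 mel frames):

```rust
// Sketch of the padding arithmetic in log_mel_spectrogram_ above.
const CHUNK_LENGTH: usize = 30; // seconds per chunk
const HOP_LENGTH: usize = 160; // samples per mel frame at 16 kHz

fn padded_len(samples: usize) -> usize {
    let n_len = samples / HOP_LENGTH;
    let pad = 100 * CHUNK_LENGTH / 2; // 1500 frames, i.e. half a chunk
    if n_len % pad != 0 {
        (n_len / pad + 1) * pad // round up to a whole multiple of pad
    } else {
        n_len
    }
}

fn main() {
    // 10 s of 16 kHz audio -> 1000 frames -> padded up to 1500 frames.
    assert_eq!(padded_len(10 * 16000), 1500);
}
```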

View File

@ -9,15 +9,28 @@ pub struct Decoder {
#[wasm_bindgen]
impl Decoder {
    #[wasm_bindgen(constructor)]
+    #[allow(clippy::too_many_arguments)]
    pub fn new(
        weights: Vec<u8>,
        tokenizer: Vec<u8>,
        mel_filters: Vec<u8>,
+        config: Vec<u8>,
+        quantized: bool,
+        is_multilingual: bool,
+        timestamps: bool,
+        task: Option<String>,
+        language: Option<String>,
    ) -> Result<Decoder, JsError> {
        let decoder = D::load(ModelData {
            tokenizer,
            mel_filters,
+            config,
+            quantized,
            weights,
+            is_multilingual,
+            timestamps,
+            task,
+            language,
        });
        match decoder {
@ -32,7 +45,6 @@ impl Decoder {
            .decoder
            .convert_and_run(&wav_input)
            .map_err(|e| JsError::new(&e.to_string()))?;
-
        let json = serde_json::to_string(&segments)?;
        Ok(json)
    }

View File

@ -0,0 +1,101 @@
pub const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
("de", "german"),
("es", "spanish"),
("ru", "russian"),
("ko", "korean"),
("fr", "french"),
("ja", "japanese"),
("pt", "portuguese"),
("tr", "turkish"),
("pl", "polish"),
("ca", "catalan"),
("nl", "dutch"),
("ar", "arabic"),
("sv", "swedish"),
("it", "italian"),
("id", "indonesian"),
("hi", "hindi"),
("fi", "finnish"),
("vi", "vietnamese"),
("he", "hebrew"),
("uk", "ukrainian"),
("el", "greek"),
("ms", "malay"),
("cs", "czech"),
("ro", "romanian"),
("da", "danish"),
("hu", "hungarian"),
("ta", "tamil"),
("no", "norwegian"),
("th", "thai"),
("ur", "urdu"),
("hr", "croatian"),
("bg", "bulgarian"),
("lt", "lithuanian"),
("la", "latin"),
("mi", "maori"),
("ml", "malayalam"),
("cy", "welsh"),
("sk", "slovak"),
("te", "telugu"),
("fa", "persian"),
("lv", "latvian"),
("bn", "bengali"),
("sr", "serbian"),
("az", "azerbaijani"),
("sl", "slovenian"),
("kn", "kannada"),
("et", "estonian"),
("mk", "macedonian"),
("br", "breton"),
("eu", "basque"),
("is", "icelandic"),
("hy", "armenian"),
("ne", "nepali"),
("mn", "mongolian"),
("bs", "bosnian"),
("kk", "kazakh"),
("sq", "albanian"),
("sw", "swahili"),
("gl", "galician"),
("mr", "marathi"),
("pa", "punjabi"),
("si", "sinhala"),
("km", "khmer"),
("sn", "shona"),
("yo", "yoruba"),
("so", "somali"),
("af", "afrikaans"),
("oc", "occitan"),
("ka", "georgian"),
("be", "belarusian"),
("tg", "tajik"),
("sd", "sindhi"),
("gu", "gujarati"),
("am", "amharic"),
("yi", "yiddish"),
("lo", "lao"),
("uz", "uzbek"),
("fo", "faroese"),
("ht", "haitian creole"),
("ps", "pashto"),
("tk", "turkmen"),
("nn", "nynorsk"),
("mt", "maltese"),
("sa", "sanskrit"),
("lb", "luxembourgish"),
("my", "myanmar"),
("bo", "tibetan"),
("tl", "tagalog"),
("mg", "malagasy"),
("as", "assamese"),
("tt", "tatar"),
("haw", "hawaiian"),
("ln", "lingala"),
("ha", "hausa"),
("ba", "bashkir"),
("jw", "javanese"),
("su", "sundanese"),
];
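As a usage sketch next to this table (assuming the `tokenizers` and `anyhow` crates, and that the code sits where `LANGUAGES` is in scope), an ISO code maps to the Whisper language token `<|xx|>` before being resolved to an id:

```rust
// Sketch: resolve a language code from LANGUAGES into its Whisper token id.
use tokenizers::Tokenizer;

fn language_token_id(tokenizer: &Tokenizer, code: &str) -> anyhow::Result<u32> {
    // Only accept codes that Whisper actually knows about.
    anyhow::ensure!(
        LANGUAGES.iter().any(|(c, _)| *c == code),
        "language {code} is not supported"
    );
    // Language tokens are spelled "<|xx|>", e.g. "<|fr|>" for French.
    let token = format!("<|{code}|>");
    tokenizer
        .token_to_id(&token)
        .ok_or_else(|| anyhow::anyhow!("no token-id for {token}"))
}
```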

View File

@ -4,14 +4,14 @@ struct Timer {
    label: &'static str,
}
-impl Timer {
-    fn new(label: &'static str) -> Self {
-        if WITH_TIMER {
-            web_sys::console::time_with_label(label);
-        }
-        Self { label }
-    }
-}
+// impl Timer {
+//     fn new(label: &'static str) -> Self {
+//         if WITH_TIMER {
+//             web_sys::console::time_with_label(label);
+//         }
+//         Self { label }
+//     }
+// }
impl Drop for Timer {
    fn drop(&mut self) {
@ -23,7 +23,7 @@ impl Drop for Timer {
mod app;
mod audio;
-mod model;
+pub mod languages;
pub mod worker;
pub use app::App;
pub use worker::Worker;

View File

@ -1,417 +0,0 @@
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub num_mel_bins: usize, // n_mels
pub max_source_positions: usize, // n_audio_ctx
pub d_model: usize, // n_audio_state
pub encoder_attention_heads: usize, // n_audio_head
pub encoder_layers: usize, // n_audio_layer
pub vocab_size: usize, // n_vocab
pub max_target_positions: usize, // n_text_ctx
// pub n_text_state: usize,
pub decoder_attention_heads: usize, // n_text_head
pub decoder_layers: usize, // n_text_layer
}
impl Config {
pub fn tiny_en() -> Self {
Self {
num_mel_bins: 80,
vocab_size: 51864,
max_source_positions: 1500,
d_model: 384,
encoder_attention_heads: 6,
encoder_layers: 4,
max_target_positions: 448,
// n_text_state: 384,
decoder_attention_heads: 6,
decoder_layers: 4,
}
}
}
// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm
// specific monitoring.
#[derive(Debug)]
struct Linear {
weight: Tensor,
bias: Option<Tensor>,
}
impl Linear {
fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _timer = crate::Timer::new("Linear::forward");
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = {
let _timer = crate::Timer::new("Linear::matmul");
x.matmul(&w)?
};
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
fn conv1d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
let bias = vb.get(out_channels, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
Ok(LayerNorm::new(weight, bias, 1e-5))
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
struct MultiHeadAttention {
query: Linear,
key: Linear,
value: Linear,
out: Linear,
n_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
}
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
let out = linear(n_state, n_state, vb.pp("out_proj"))?;
Ok(Self {
query,
key,
value,
out,
n_head,
kv_cache: None,
})
}
fn forward(
&mut self,
x: &Tensor,
xa: Option<&Tensor>,
mask: Option<&Tensor>,
flush_cache: bool,
) -> Result<Tensor> {
let _timer = crate::Timer::new("MultiHeadAttention::forward");
let q = self.query.forward(x)?;
let (k, v) = match xa {
None => {
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
(k, v)
}
Some(x) => {
if flush_cache {
self.kv_cache = None;
}
if let Some((k, v)) = &self.kv_cache {
(k.clone(), v.clone())
} else {
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
self.kv_cache = Some((k.clone(), v.clone()));
(k, v)
}
}
};
let wv = self.qkv_attention(&q, &k, &v, mask)?;
let out = self.out.forward(&wv)?;
Ok(out)
}
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
}
fn qkv_attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
mask: Option<&Tensor>,
) -> Result<Tensor> {
let (_, n_ctx, n_state) = q.dims3()?;
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
let q = (self.reshape_head(q)? * scale)?;
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
let v = self.reshape_head(v)?.contiguous()?;
let mut qk = {
let _timer = crate::Timer::new("qk::matmul");
q.matmul(&k)?
};
if let Some(mask) = mask {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
let w = {
let _timer = crate::Timer::new("qk::softmax");
candle_nn::ops::softmax(&qk, candle::D::Minus1)?
};
let wv = {
let _timer = crate::Timer::new("wv::matmul");
w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?
};
Ok(wv)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
struct ResidualAttentionBlock {
attn: MultiHeadAttention,
attn_ln: LayerNorm,
cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
mlp_linear1: Linear,
mlp_linear2: Linear,
mlp_ln: LayerNorm,
}
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
Some((cross_attn, cross_attn_ln))
} else {
None
};
let n_mlp = n_state * 4;
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
Ok(Self {
attn,
attn_ln,
cross_attn,
mlp_linear1,
mlp_linear2,
mlp_ln,
})
}
fn forward(
&mut self,
x: &Tensor,
xa: Option<&Tensor>,
mask: Option<&Tensor>,
flush_kv_cache: bool,
) -> Result<Tensor> {
let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
let attn = self
.attn
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &mut self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
}
let mlp = self.mlp_linear2.forward(
&self
.mlp_linear1
.forward(&self.mlp_ln.forward(&x)?)?
.gelu()?,
)?;
Ok((x + mlp)?)
}
}
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
let max_timescale = 10000f32;
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
let inv_timescales: Vec<_> = (0..channels / 2)
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
.collect();
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
.to_dtype(candle::DType::F32)?
.unsqueeze(1)?;
let sh = (length, channels / 2);
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
Ok(sincos)
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
pub struct AudioEncoder {
conv1: Conv1d,
conv2: Conv1d,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
}
impl AudioEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
let cfg1 = Conv1dConfig {
padding: 1,
stride: 1,
groups: 1,
dilation: 1,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
dilation: 1,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
Ok(Self {
conv1,
conv2,
positional_embedding,
blocks,
ln_post,
})
}
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
let _timer = crate::Timer::new("AudioEncoder::forward");
let x = {
let _timer = crate::Timer::new("conv1::forward");
self.conv1.forward(x)?.gelu()?
};
let x = {
let _timer = crate::Timer::new("conv2::forward");
self.conv2.forward(&x)?.gelu()?
};
let x = x.transpose(1, 2)?;
let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
let mut x = x.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter_mut() {
x = block.forward(&x, None, None, flush_kv_cache)?
}
let x = self.ln_post.forward(&x)?;
Ok(x)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
pub struct TextDecoder {
token_embedding: Embedding,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
}
impl TextDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let _timer = crate::Timer::new("TextDecoder::forward");
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
let blocks = (0..cfg.decoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
Ok(Self {
token_embedding,
positional_embedding,
blocks,
ln,
mask,
})
}
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter_mut() {
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
}
let x = self.ln.forward(&x)?;
let w = self
.token_embedding
.embeddings()
.broadcast_left(x_dims[0])?;
let logits = x.matmul(&w.t()?)?;
Ok(logits)
}
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
pub struct Whisper {
pub encoder: AudioEncoder,
pub decoder: TextDecoder,
pub config: Config,
}
impl Whisper {
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
Ok(Self {
encoder,
decoder,
config,
})
}
}

View File

@ -1,7 +1,8 @@
-use crate::model::{Config, Whisper};
+use crate::languages::LANGUAGES;
use anyhow::Error as E;
-use candle::{safetensors::Load, DType, Device, Tensor};
+use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
use candle_nn::{ops::softmax, VarBuilder};
+pub use candle_transformers::models::whisper::{self as m, Config};
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
@ -25,38 +26,46 @@ macro_rules! console_log {
pub const DTYPE: DType = DType::F32;
-// Audio parameters.
-pub const SAMPLE_RATE: usize = 16000;
-pub const N_FFT: usize = 400;
-pub const N_MELS: usize = 80;
-pub const HOP_LENGTH: usize = 160;
-pub const CHUNK_LENGTH: usize = 30;
-pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
-pub const LOGPROB_THRESHOLD: f64 = -1.0;
-pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-
-// Tokenizer dependent bits.
-const SOT_TOKEN: &str = "<|startoftranscript|>";
-const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
-const TRANSLATE_TOKEN: &str = "<|translate|>";
-const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
-const EOT_TOKEN: &str = "<|endoftext|>";
-const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
-
-// From the _get_suppress_tokens function + 50362 (no timestamp)
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
-pub const SUPPRESS_TOKENS: [u32; 91] = [
-    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
-    366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
-    1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
-    10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
-    19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
-    47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
-];
+pub enum Model {
+    Normal(m::model::Whisper),
+    Quantized(m::quantized_model::Whisper),
+}
+
+// Maybe we should use some traits rather than doing the dispatch for all these.
+impl Model {
+    pub fn config(&self) -> &Config {
+        match self {
+            Self::Normal(m) => &m.config,
+            Self::Quantized(m) => &m.config,
+        }
+    }
+
+    pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.encoder.forward(x, flush),
+            Self::Quantized(m) => m.encoder.forward(x, flush),
+        }
+    }
+
+    pub fn decoder_forward(
+        &mut self,
+        x: &Tensor,
+        xa: &Tensor,
+        flush: bool,
+    ) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.forward(x, xa, flush),
+            Self::Quantized(m) => m.decoder.forward(x, xa, flush),
+        }
+    }
+
+    pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.final_linear(x),
+            Self::Quantized(m) => m.decoder.final_linear(x),
+        }
+    }
+}
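The comment above hints at a trait-based alternative to this match dispatch. Purely as an illustration (not part of this commit), such a trait could look like the sketch below, reusing the `m` alias and `Config` re-export from the imports at the top of the file:

```rust
use candle::Tensor;
use candle_transformers::models::whisper::{self as m, Config};

// Hypothetical trait capturing the operations Model dispatches on.
trait WhisperModel {
    fn config(&self) -> &Config;
    fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor>;
    fn decoder_forward(&mut self, x: &Tensor, xa: &Tensor, flush: bool) -> candle::Result<Tensor>;
    fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor>;
}

impl WhisperModel for m::model::Whisper {
    fn config(&self) -> &Config {
        &self.config
    }
    fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
        self.encoder.forward(x, flush)
    }
    fn decoder_forward(&mut self, x: &Tensor, xa: &Tensor, flush: bool) -> candle::Result<Tensor> {
        self.decoder.forward(x, xa, flush)
    }
    fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
        self.decoder.final_linear(x)
    }
}
```

An identical impl for `m::quantized_model::Whisper` would then let the decoder hold a `Box<dyn WhisperModel>` instead of the enum.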
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecodingResult {
@ -77,8 +86,13 @@ pub struct Segment {
#[allow(unused)]
pub struct Decoder {
-    model: Whisper,
+    model: Model,
+    rng: rand::rngs::StdRng,
+    task: Option<Task>,
+    language: Option<String>,
+    is_multilingual: bool,
    mel_filters: Vec<f32>,
+    timestamps: bool,
    tokenizer: Tokenizer,
    suppress_tokens: Tensor,
    sot_token: u32,
@ -90,32 +104,43 @@ pub struct Decoder {
}
impl Decoder {
+    #[allow(clippy::too_many_arguments)]
    fn new(
-        model: Whisper,
+        model: Model,
        tokenizer: Tokenizer,
        mel_filters: Vec<f32>,
        device: &Device,
+        task: Option<Task>,
+        language: Option<String>,
+        is_multilingual: bool,
+        timestamps: bool,
    ) -> anyhow::Result<Self> {
-        let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
+        let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
            .map(|i| {
-                if SUPPRESS_TOKENS.contains(&i) {
+                if model.config().suppress_tokens.contains(&i) {
                    f32::NEG_INFINITY
                } else {
                    0f32
                }
            })
            .collect();
-        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
        let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
-        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
-        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
-        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
-        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
+        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
+        let seed = 299792458;
        Ok(Self {
            model,
-            mel_filters,
+            rng: StdRng::seed_from_u64(seed),
            tokenizer,
+            mel_filters,
+            task,
+            timestamps,
+            language,
+            is_multilingual,
            suppress_tokens,
            sot_token,
            transcribe_token,
@ -126,40 +151,73 @@ impl Decoder {
        })
    }
-    fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
+    fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
        let model = &mut self.model;
-        let audio_features = model.encoder.forward(mel, true)?;
-        console_log!("audio features: {:?}", audio_features.dims());
-        let sample_len = model.config.max_target_positions / 2;
+        let language_token = match (self.is_multilingual, &self.language) {
+            (true, None) => Some(detect_language(model, &self.tokenizer, mel)?),
+            (false, None) => None,
+            (true, Some(language)) => {
+                match token_id(&self.tokenizer, &format!("<|{:?}|>", self.language)) {
+                    Ok(token_id) => Some(token_id),
+                    Err(_) => anyhow::bail!("language {language} is not supported"),
+                }
+            }
+            (false, Some(_)) => {
+                anyhow::bail!("a language cannot be set for non-multilingual models")
+            }
+        };
+        let audio_features = model.encoder_forward(mel, true)?;
+        println!("audio features: {:?}", audio_features.dims());
+        let sample_len = model.config().max_target_positions / 2;
        let mut sum_logprob = 0f64;
        let mut no_speech_prob = f64::NAN;
-        let mut tokens = vec![self.sot_token, self.transcribe_token];
+        let mut tokens = vec![self.sot_token];
+        if let Some(language_token) = language_token {
+            tokens.push(language_token);
+        }
+        match self.task {
+            None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+            Some(Task::Translate) => tokens.push(self.translate_token),
+        }
+        if !self.timestamps {
+            tokens.push(self.no_timestamps_token);
+        }
        for i in 0..sample_len {
            let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
            // The model expects a batch dim but this inference loop does not handle
            // it so we add it at this point.
            let tokens_t = tokens_t.unsqueeze(0)?;
-            let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
-            let logits = logits.squeeze(0)?;
+            let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
            // Extract the no speech probability on the first iteration by looking at the first
            // token logits and the probability for the according token.
            if i == 0 {
-                no_speech_prob = softmax(&logits.get(0)?, 0)?
-                    .get(self.no_speech_token as usize)?
+                let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+                no_speech_prob = softmax(&logits, 0)?
+                    .i(self.no_speech_token as usize)?
                    .to_scalar::<f32>()? as f64;
            }
-            let (seq_len, _) = logits.dims2()?;
-            let logits = logits
-                .get(seq_len - 1)?
-                .broadcast_add(&self.suppress_tokens)?;
+            let (_, seq_len, _) = ys.dims3()?;
+            let logits = model
+                .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
+                .i(0)?
+                .i(0)?;
+            // TODO: Besides suppress tokens, we should apply the heuristics from
+            // ApplyTimestampRules, i.e.:
+            // - Timestamps come in pairs, except before EOT.
+            // - Timestamps should be non-decreasing.
+            // - If the sum of the probabilities of timestamps is higher than any other tokens,
+            //   only consider timestamps when sampling.
+            // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
+            let logits = logits.broadcast_add(&self.suppress_tokens)?;
            let next_token = if t > 0f64 {
                let prs = softmax(&(&logits / t)?, 0)?;
                let logits_v: Vec<f32> = prs.to_vec1()?;
                let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-                distr.sample(rng) as u32
+                distr.sample(&mut self.rng) as u32
            } else {
                let logits_v: Vec<f32> = logits.to_vec1()?;
                logits_v
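For clarity, the prompt assembled above always starts with the start-of-transcript token, optionally followed by a language token (multilingual models only), then the task token, and, when timestamps are disabled, the no-timestamps token. A small sketch of that ordering (token ids are illustrative only):

```rust
// Sketch of the decoder prompt built in decode() above.
fn build_prompt(
    sot_token: u32,
    language_token: Option<u32>, // Some(..) only for multilingual models
    task_token: u32,             // transcribe_token or translate_token
    no_timestamps_token: u32,
    timestamps: bool,
) -> Vec<u32> {
    let mut tokens = vec![sot_token];
    if let Some(language_token) = language_token {
        tokens.push(language_token);
    }
    tokens.push(task_token);
    if !timestamps {
        tokens.push(no_timestamps_token);
    }
    tokens
}

fn main() {
    // e.g. sot, <|fr|>, transcribe, no-timestamps (ids made up for the example)
    assert_eq!(
        build_prompt(50258, Some(50265), 50359, 50363, false),
        vec![50258, 50265, 50359, 50363]
    );
}
```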
@ -171,9 +229,9 @@ impl Decoder {
            };
            tokens.push(next_token);
            let prob = softmax(&logits, candle::D::Minus1)?
-                .get(next_token as usize)?
+                .i(next_token as usize)?
                .to_scalar::<f32>()? as f64;
-            if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
+            if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
                break;
            }
            sum_logprob += prob.ln();
@ -191,22 +249,18 @@ impl Decoder {
        })
    }
-    fn decode_with_fallback(
-        &mut self,
-        segment: &Tensor,
-        rng: &mut StdRng,
-    ) -> anyhow::Result<DecodingResult> {
-        for (i, &t) in TEMPERATURES.iter().enumerate() {
-            let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
-            if i == TEMPERATURES.len() - 1 {
+    fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
+        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
+            let dr: Result<DecodingResult, _> = self.decode(segment, t);
+            if i == m::TEMPERATURES.len() - 1 {
                return dr;
            }
            // On errors, we try again with a different temperature.
            match dr {
                Ok(dr) => {
-                    let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                        || dr.avg_logprob < LOGPROB_THRESHOLD;
-                    if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
                        return Ok(dr);
                    }
                }
@ -219,18 +273,17 @@ impl Decoder {
    }
    fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
-        let mut rng = StdRng::seed_from_u64(299792458);
        let (_, _, content_frames) = mel.dims3()?;
        let mut seek = 0;
        let mut segments = vec![];
        while seek < content_frames {
-            let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
-            let segment_size = usize::min(content_frames - seek, N_FRAMES);
+            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
            let mel_segment = mel.narrow(2, seek, segment_size)?;
-            let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
-            let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
+            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let dr = self.decode_with_fallback(&mel_segment)?;
            seek += segment_size;
-            if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
                console_log!("no speech detected, skipping {seek} {dr:?}");
                continue;
            }
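A worked example of the seek arithmetic above, using the whisper constants re-exported through `m` (HOP_LENGTH = 160, SAMPLE_RATE = 16000, N_FRAMES = 3000, i.e. one mel frame per 10 ms and 30 s windows):

```rust
// Sketch of the segment timing in run() above.
const HOP_LENGTH: usize = 160;
const SAMPLE_RATE: usize = 16000;
const N_FRAMES: usize = 3000; // one 30 s window

fn main() {
    let content_frames = 4500; // e.g. 45 s of audio
    let mut seek = 0;
    while seek < content_frames {
        let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
        let segment_size = usize::min(content_frames - seek, N_FRAMES);
        let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
        println!("segment at {time_offset:.1}s lasting {segment_duration:.1}s");
        seek += segment_size;
    }
    // Prints: a segment at 0.0s lasting 30.0s, then one at 30.0s lasting 15.0s.
}
```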
@ -247,17 +300,39 @@ impl Decoder {
    pub fn load(md: ModelData) -> anyhow::Result<Self> {
        let device = Device::Cpu;
-        let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
+        let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(E::msg)?;
        let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
        let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
        console_log!("loaded mel filters {:?}", mel_filters.shape());
        let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
-        let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?;
-        let config = Config::tiny_en();
-        let whisper = Whisper::load(&vb, config)?;
+        let config: Config = serde_json::from_slice(&md.config)?;
+        let model = if md.quantized {
+            let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
+                &md.weights,
+            )?;
+            Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
+        } else {
+            let vb = VarBuilder::from_buffered_safetensors(md.weights, m::DTYPE, &device)?;
+            Model::Normal(m::model::Whisper::load(&vb, config)?)
+        };
        console_log!("done loading model");
-        let decoder = Self::new(whisper, tokenizer, mel_filters, &device)?;
+        let task = match md.task.as_deref() {
+            Some("translate") => Some(Task::Translate),
+            _ => Some(Task::Transcribe),
+        };
+        let decoder = Self::new(
+            model,
+            tokenizer,
+            mel_filters,
+            &device,
+            task,
+            md.language,
+            md.is_multilingual,
+            md.timestamps,
+        )?;
        Ok(decoder)
    }
@ -266,8 +341,8 @@ impl Decoder {
        let mut wav_input = std::io::Cursor::new(wav_input);
        let (header, data) = wav::read(&mut wav_input)?;
        console_log!("loaded wav data: {header:?}");
-        if header.sampling_rate != SAMPLE_RATE as u32 {
-            anyhow::bail!("wav file must have a {SAMPLE_RATE} sampling rate");
+        if header.sampling_rate != m::SAMPLE_RATE as u32 {
+            anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
        }
        let data = data.as_sixteen().expect("expected 16 bit wav file");
        let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@ -277,27 +352,74 @@ impl Decoder {
        console_log!("pcm data loaded {}", pcm_data.len());
        let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?;
        let mel_len = mel.len();
-        let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+        let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
        console_log!("loaded mel: {:?}", mel.dims());
        let segments = self.run(&mel)?;
        Ok(segments)
    }
}
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32, E> {
console_log!("detecting language");
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
.map(|e| e.map_err(E::msg))
.collect::<Result<Vec<_>, E>>()?;
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}")
}
let token = &format!("<|{}|>", probs[0].0 .0);
let language = token_id(tokenizer, token)?;
console_log!("detected language: {language} {token}");
Ok(language)
}
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
    match tokenizer.token_to_id(token) {
        None => candle::bail!("no token-id for {token}"),
        Some(id) => Ok(id),
    }
}
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
pub enum Task {
Transcribe,
Translate,
}
// Communication to the worker happens through bincode, the model weights and configs are fetched
// on the main thread and transfered via the following structure.
#[derive(Serialize, Deserialize)]
pub struct ModelData {
+    pub weights: Vec<u8>,
    pub tokenizer: Vec<u8>,
    pub mel_filters: Vec<u8>,
-    pub weights: Vec<u8>,
+    pub config: Vec<u8>,
+    pub quantized: bool,
+    pub timestamps: bool,
+    pub is_multilingual: bool,
+    pub language: Option<String>,
+    pub task: Option<String>,
}

pub struct Worker {
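As an illustration of the new fields, this is roughly the `ModelData` a caller would assemble for the quantized tiny.en assets from the README. Native file reads stand in here for the browser-side fetches, and the crate path is assumed:

```rust
// Sketch only: build ModelData for the quantized tiny.en model and load the decoder.
use candle_wasm_example_whisper::worker::{Decoder, ModelData}; // crate name assumed

fn main() -> anyhow::Result<()> {
    let md = ModelData {
        weights: std::fs::read("quantized/model-tiny-en-q80.gguf")?,
        tokenizer: std::fs::read("quantized/tokenizer-tiny-en.json")?,
        mel_filters: std::fs::read("mel_filters.safetensors")?,
        config: std::fs::read("quantized/config-tiny-en.json")?,
        quantized: true,
        timestamps: true,
        is_multilingual: false,
        language: None,
        task: Some("transcribe".to_string()),
    };
    let _decoder = Decoder::load(md)?;
    Ok(())
}
```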

View File

@ -17,23 +17,46 @@ class Whisper {
  static instance = {};
  // Retrieve the Whisper model. When called for the first time,
  // this will load the model and save it for future use.
-  static async getInstance(weightsURL, modelID, tokenizerURL, mel_filtersURL) {
+  static async getInstance(params) {
+    const {
+      weightsURL,
+      modelID,
+      tokenizerURL,
+      mel_filtersURL,
+      configURL,
+      quantized,
+      is_multilingual,
+      timestamps,
+      task,
+      language,
+    } = params;
    // load individual modelID only once
    if (!this.instance[modelID]) {
      await init();
      self.postMessage({ status: "loading", message: "Loading Model" });
-      const [weightsArrayU8, tokenizerArrayU8, mel_filtersArrayU8] =
-        await Promise.all([
-          fetchArrayBuffer(weightsURL),
-          fetchArrayBuffer(tokenizerURL),
-          fetchArrayBuffer(mel_filtersURL),
-        ]);
+      const [
+        weightsArrayU8,
+        tokenizerArrayU8,
+        mel_filtersArrayU8,
+        configArrayU8,
+      ] = await Promise.all([
+        fetchArrayBuffer(weightsURL),
+        fetchArrayBuffer(tokenizerURL),
+        fetchArrayBuffer(mel_filtersURL),
+        fetchArrayBuffer(configURL),
+      ]);
      this.instance[modelID] = new Decoder(
        weightsArrayU8,
        tokenizerArrayU8,
-        mel_filtersArrayU8
+        mel_filtersArrayU8,
+        configArrayU8,
+        quantized,
+        is_multilingual,
+        timestamps,
+        task,
+        language
      );
    } else {
      self.postMessage({ status: "loading", message: "Model Already Loaded" });
@ -43,17 +66,37 @@ class Whisper {
}
self.addEventListener("message", async (event) => {
-  const { weightsURL, modelID, tokenizerURL, mel_filtersURL, audioURL } =
-    event.data;
+  const {
+    weightsURL,
+    modelID,
+    tokenizerURL,
+    configURL,
+    mel_filtersURL,
+    audioURL,
+  } = event.data;
  try {
    self.postMessage({ status: "decoding", message: "Starting Decoder" });
-    const decoder = await Whisper.getInstance(
+    let quantized = false;
+    if (modelID.includes("quantized")) {
+      quantized = true;
+    }
+    let is_multilingual = false;
+    if (modelID.includes("multilingual")) {
+      is_multilingual = true;
+    }
+    let timestamps = true;
+    const decoder = await Whisper.getInstance({
      weightsURL,
      modelID,
      tokenizerURL,
-      mel_filtersURL
-    );
+      mel_filtersURL,
+      configURL,
+      quantized,
+      is_multilingual,
+      timestamps,
+      task: null,
+      language: null,
+    });
    self.postMessage({ status: "decoding", message: "Loading Audio" });
    const audioArrayU8 = await fetchArrayBuffer(audioURL);
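The worker derives the `quantized` and `is_multilingual` flags purely from the model key used in the `MODELS` map; the same parsing expressed in Rust, as an illustration only:

```rust
// Mirror of the modelID-based flag detection in the JS worker above.
fn flags_from_model_id(model_id: &str) -> (bool, bool) {
    let quantized = model_id.contains("quantized");
    let is_multilingual = model_id.contains("multilingual");
    (quantized, is_multilingual)
}

fn main() {
    assert_eq!(flags_from_model_id("tiny_quantized_multilingual_q80"), (true, true));
    assert_eq!(flags_from_model_id("tiny_en"), (false, false));
}
```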