mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Whisper quantized wasm (#1028)
* [Whisper] Update to use quantized model
* [whisper] add language detection
* [whisper] change assets location
* [whisper] adapt js example with quantized models
* [whisper] better task parsing
* [whisper] minor fixes
This commit is contained in:
7
.gitignore
vendored
7
.gitignore
vendored
@ -29,9 +29,10 @@ trace-*.json
|
||||
candle-wasm-examples/*/build
|
||||
candle-wasm-examples/*/*.bin
|
||||
candle-wasm-examples/*/*.jpeg
|
||||
candle-wasm-examples/*/*.wav
|
||||
candle-wasm-examples/*/*.safetensors
|
||||
candle-wasm-examples/*/audios/*.wav
|
||||
candle-wasm-examples/**/*.safetensors
|
||||
candle-wasm-examples/**/*.gguf
|
||||
candle-wasm-examples/*/package-lock.json
|
||||
|
||||
candle-wasm-examples/**/config*.json
|
||||
.DS_Store
|
||||
.idea/*
|
||||
|
@ -11,6 +11,7 @@ license.workspace = true
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
|
@ -10,19 +10,31 @@ From the `candle-wasm-examples/whisper` directory run:
|
||||
Download assets:
|
||||
|
||||
```bash
|
||||
# Model and tokenizer
|
||||
# mel filters
|
||||
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors
|
||||
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tiny.en.safetensors
|
||||
wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tokenizer.en.json
|
||||
# Model and tokenizer tiny.en
|
||||
wget -c https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors -P whisper-tiny.en
|
||||
wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json -P whisper-tiny.en
|
||||
wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json -P whisper-tiny.en
|
||||
# Model and tokenizer tiny (multilingual)
|
||||
wget -c https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors -P whisper-tiny
|
||||
wget -c https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json -P whisper-tiny
|
||||
wget -c https://huggingface.co/openai/whisper-tiny/raw/main/config.json -P whisper-tiny
|
||||
|
||||
# Quantized (q80) tiny.en model, tokenizer and config
|
||||
wget -c https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf -P quantized
|
||||
wget -c https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json -P quantized
|
||||
wget -c https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json -P quantized
|
||||
|
||||
|
||||
|
||||
# Audio samples
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -O gb0.wav
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -O a13.wav
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -O gb1.wav
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -O hp0.wav
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -O jfk.wav
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -O mm0.wav
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -P audios
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -P audios
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -P audios
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -P audios
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -P audios
|
||||
wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -P audios
|
||||
|
||||
```
|
||||
|
||||
|
@ -3,22 +3,38 @@
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>Welcome to Candle!</title>
|
||||
|
||||
<link data-trunk rel="copy-file" href="jfk.wav" />
|
||||
<link data-trunk rel="copy-file" href="mm0.wav" />
|
||||
<link data-trunk rel="copy-file" href="a13.wav" />
|
||||
<link data-trunk rel="copy-file" href="gb0.wav" />
|
||||
<link data-trunk rel="copy-file" href="gb1.wav" />
|
||||
<link data-trunk rel="copy-file" href="hp0.wav" />
|
||||
<link data-trunk rel="copy-file" href="tokenizer.en.json" />
|
||||
<link data-trunk rel="copy-file" href="mel_filters.safetensors" />
|
||||
<link data-trunk rel="copy-file" href="tiny.en.safetensors" />
|
||||
<link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
|
||||
<link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
|
||||
<!-- samples -->
|
||||
<link data-trunk rel="copy-dir" href="audios" />
|
||||
<!-- tiny.en -->
|
||||
<link data-trunk rel="copy-dir" href="whisper-tiny.en" />
|
||||
<!-- tiny -->
|
||||
<link data-trunk rel="copy-dir" href="whisper-tiny" />
|
||||
<!-- quantized -->
|
||||
<link data-trunk rel="copy-dir" href="quantized" />
|
||||
|
||||
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
|
||||
<link
|
||||
data-trunk
|
||||
rel="rust"
|
||||
href="Cargo.toml"
|
||||
data-bin="app"
|
||||
data-type="main" />
|
||||
<link
|
||||
data-trunk
|
||||
rel="rust"
|
||||
href="Cargo.toml"
|
||||
data-bin="worker"
|
||||
data-type="worker" />
|
||||
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic" />
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css" />
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css" />
|
||||
</head>
|
||||
<body></body>
|
||||
</html>
|
||||
|
@ -26,9 +26,30 @@
|
||||
|
||||
// models base url
|
||||
// Remote assets for every model offered in the "Models Options" dropdown.
// Keys match the <option value="..."> attributes; each entry's file names are
// appended to `base_url` to build the weights, tokenizer and config URLs.
const MODELS = {
  tiny_multilingual: {
    base_url: "https://huggingface.co/openai/whisper-tiny/resolve/main/",
    model: "model.safetensors",
    tokenizer: "tokenizer.json",
    config: "config.json",
  },
  tiny_en: {
    base_url: "https://huggingface.co/openai/whisper-tiny.en/resolve/main/",
    model: "model.safetensors",
    tokenizer: "tokenizer.json",
    config: "config.json",
  },
  tiny_quantized_multilingual_q80: {
    base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
    model: "model-tiny-q80.gguf",
    tokenizer: "tokenizer-tiny.json",
    config: "config-tiny.json",
  },
  tiny_en_quantized_q80: {
    base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
    // English-only weights: must be the "-en-" file so that it matches the
    // English tokenizer/config below (the multilingual file is listed above).
    model: "model-tiny-en-q80.gguf",
    tokenizer: "tokenizer-tiny-en.json",
    config: "config-tiny-en.json",
  },
};
|
||||
const whisperWorker = new Worker("./whisperWorker.js", {
|
||||
@ -39,6 +60,7 @@
|
||||
weightsURL, // URL to the weights file
|
||||
modelID, // model ID
|
||||
tokenizerURL, // URL to the tokenizer file
|
||||
configURL, // model config URL
|
||||
mel_filtersURL, // URL to the mel filters file
|
||||
audioURL, // URL to the audio file
|
||||
updateStatus // function to update the status
|
||||
@ -48,6 +70,7 @@
|
||||
weightsURL,
|
||||
modelID,
|
||||
tokenizerURL,
|
||||
configURL,
|
||||
mel_filtersURL,
|
||||
audioURL,
|
||||
});
|
||||
@ -128,13 +151,16 @@
|
||||
return;
|
||||
}
|
||||
const modelID = document.querySelector("#model").value;
|
||||
const modelURL = MODELS[modelID].base_url + "model.safetensors";
|
||||
const tokenizerURL = MODELS[modelID].base_url + "tokenizer.json";
|
||||
const model = MODELS[modelID];
|
||||
const modelURL = model.base_url + model.model;
|
||||
const tokenizerURL = model.base_url + model.tokenizer;
|
||||
const configURL = model.base_url + model.config;
|
||||
|
||||
classifyAudio(
|
||||
modelURL,
|
||||
modelID,
|
||||
tokenizerURL,
|
||||
configURL,
|
||||
"mel_filters.safetensors",
|
||||
audioURL,
|
||||
updateStatus
|
||||
@ -178,8 +204,7 @@
|
||||
<a
|
||||
href="https://huggingface.co/openai/"
|
||||
target="_blank"
|
||||
class="underline hover:text-blue-500 hover:no-underline"
|
||||
>
|
||||
class="underline hover:text-blue-500 hover:no-underline">
|
||||
OpenAI Whisper models
|
||||
</a>
|
||||
and WASM runtime built with
|
||||
@ -196,37 +221,38 @@
|
||||
<label for="model" class="font-medium">Models Options: </label>
|
||||
<!-- Model picker: each value is a key of the MODELS table in the script above.
     Exactly one option may be pre-selected in a single-select; tiny.en is the default. -->
<select
  id="model"
  class="border-2 border-gray-500 rounded-md font-light">
  <option value="tiny_multilingual">tiny (151 MB)</option>
  <option value="tiny_en" selected>tiny.en (151 MB)</option>
  <option value="tiny_quantized_multilingual_q80">
    tiny quantized q80 (41.5 MB)
  </option>
  <option value="tiny_en_quantized_q80">
    tiny.en quantized q80 (41.8 MB)
  </option>
</select>
|
||||
</div>
|
||||
<!-- drag and drop area -->
|
||||
<div class="relative">
|
||||
<div
|
||||
id="drop-area"
|
||||
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden"
|
||||
>
|
||||
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden">
|
||||
<div
|
||||
class="flex flex-col items-center justify-center space-y-1 text-center"
|
||||
>
|
||||
class="flex flex-col items-center justify-center space-y-1 text-center">
|
||||
<svg
|
||||
width="25"
|
||||
height="25"
|
||||
viewBox="0 0 25 25"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
xmlns="http://www.w3.org/2000/svg">
|
||||
<path
|
||||
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
|
||||
fill="#000"
|
||||
/>
|
||||
fill="#000" />
|
||||
</svg>
|
||||
<div class="flex text-sm text-gray-600">
|
||||
<label
|
||||
for="file-upload"
|
||||
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
|
||||
>
|
||||
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700">
|
||||
<span>Drag and drop your audio here</span>
|
||||
<span class="block text-xs">or</span>
|
||||
<span class="block text-xs">Click to upload</span>
|
||||
@ -237,15 +263,13 @@
|
||||
name="file-upload"
|
||||
type="file"
|
||||
accept="audio/*"
|
||||
class="sr-only"
|
||||
/>
|
||||
class="sr-only" />
|
||||
</div>
|
||||
<audio
|
||||
id="audio"
|
||||
hidden
|
||||
controls
|
||||
class="w-full p-2 select-none"
|
||||
></audio>
|
||||
class="w-full p-2 select-none"></audio>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
@ -253,43 +277,37 @@
|
||||
<h3 class="font-medium">Examples:</h3>
|
||||
<button
|
||||
data-value="samples_jfk.wav"
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
|
||||
>
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
|
||||
<span>jfk.wav</span>
|
||||
<span class="text-xs block"> (352 kB)</span>
|
||||
</button>
|
||||
<button
|
||||
data-value="samples_a13.wav"
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
|
||||
>
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
|
||||
<span>a13.wav</span>
|
||||
<span class="text-xs block"> (960 kB)</span>
|
||||
</button>
|
||||
<button
|
||||
data-value="samples_mm0.wav"
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
|
||||
>
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
|
||||
<span>mm0.wav</span>
|
||||
<span class="text-xs block new"> (957 kB)</span>
|
||||
</button>
|
||||
<button
|
||||
data-value="samples_gb0.wav"
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
|
||||
>
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
|
||||
<span>gb0.wav </span>
|
||||
<span class="text-xs block">(4.08 MB)</span>
|
||||
</button>
|
||||
<button
|
||||
data-value="samples_gb1.wav"
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
|
||||
>
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
|
||||
<span>gb1.wav </span>
|
||||
<span class="text-xs block">(6.36 MB)</span>
|
||||
</button>
|
||||
<button
|
||||
data-value="samples_hp0.wav"
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
|
||||
>
|
||||
class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
|
||||
<span>hp0.wav </span>
|
||||
<span class="text-xs block">(8.75 MB)</span>
|
||||
</button>
|
||||
@ -300,16 +318,14 @@
|
||||
<button
|
||||
id="detect"
|
||||
disabled
|
||||
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
|
||||
>
|
||||
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
|
||||
Transcribe Audio
|
||||
</button>
|
||||
</div>
|
||||
<div>
|
||||
<h3 class="font-medium">Transcription:</h3>
|
||||
<div
|
||||
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
|
||||
>
|
||||
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
|
||||
<p hidden id="output-generation" class="grid-rows-2"></p>
|
||||
<span id="output-status" class="m-auto font-light"
|
||||
>No transcription results yet</span
|
||||
|
@ -7,7 +7,12 @@ use yew::{html, Component, Context, Html};
|
||||
use yew_agent::{Bridge, Bridged};
|
||||
|
||||
const SAMPLE_NAMES: [&str; 6] = [
|
||||
"jfk.wav", "a13.wav", "gb0.wav", "gb1.wav", "hp0.wav", "mm0.wav",
|
||||
"audios/samples_jfk.wav",
|
||||
"audios/samples_a13.wav",
|
||||
"audios/samples_gb0.wav",
|
||||
"audios/samples_gb1.wav",
|
||||
"audios/samples_hp0.wav",
|
||||
"audios/samples_mm0.wav",
|
||||
];
|
||||
|
||||
async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
|
||||
@ -54,14 +59,46 @@ pub struct App {
|
||||
}
|
||||
|
||||
async fn model_data_load() -> Result<ModelData, JsValue> {
|
||||
let tokenizer = fetch_url("tokenizer.en.json").await?;
|
||||
let mel_filters = fetch_url("mel_filters.safetensors").await?;
|
||||
let weights = fetch_url("tiny.en.safetensors").await?;
|
||||
let quantized = false;
|
||||
let is_multilingual = false;
|
||||
|
||||
let (tokenizer, mel_filters, weights, config) = if quantized {
|
||||
console_log!("loading quantized weights");
|
||||
let tokenizer = fetch_url("quantized/tokenizer-tiny-en.json").await?;
|
||||
let mel_filters = fetch_url("mel_filters.safetensors").await?;
|
||||
let weights = fetch_url("quantized/model-tiny-en-q80.gguf").await?;
|
||||
let config = fetch_url("quantized/config-tiny-en.json").await?;
|
||||
(tokenizer, mel_filters, weights, config)
|
||||
} else {
|
||||
console_log!("loading float weights");
|
||||
if is_multilingual {
|
||||
let mel_filters = fetch_url("mel_filters.safetensors").await?;
|
||||
let tokenizer = fetch_url("whisper-tiny/tokenizer.json").await?;
|
||||
let weights = fetch_url("whisper-tiny/model.safetensors").await?;
|
||||
let config = fetch_url("whisper-tiny/config.json").await?;
|
||||
(tokenizer, mel_filters, weights, config)
|
||||
} else {
|
||||
let mel_filters = fetch_url("mel_filters.safetensors").await?;
|
||||
let tokenizer = fetch_url("whisper-tiny.en/tokenizer.json").await?;
|
||||
let weights = fetch_url("whisper-tiny.en/model.safetensors").await?;
|
||||
let config = fetch_url("whisper-tiny.en/config.json").await?;
|
||||
(tokenizer, mel_filters, weights, config)
|
||||
}
|
||||
};
|
||||
|
||||
let timestamps = true;
|
||||
let _task = Some("transcribe".to_string());
|
||||
console_log!("{}", weights.len());
|
||||
Ok(ModelData {
|
||||
tokenizer,
|
||||
mel_filters,
|
||||
weights,
|
||||
config,
|
||||
quantized,
|
||||
timestamps,
|
||||
task: None,
|
||||
is_multilingual,
|
||||
language: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -168,7 +168,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
let n_len = samples.len() / fft_step;
|
||||
|
||||
// pad audio with at least one extra chunk of zeros
|
||||
let pad = 100 * worker::CHUNK_LENGTH / 2;
|
||||
let pad = 100 * worker::m::CHUNK_LENGTH / 2;
|
||||
let n_len = if n_len % pad != 0 {
|
||||
(n_len / pad + 1) * pad
|
||||
} else {
|
||||
@ -206,9 +206,9 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
let mel = log_mel_spectrogram_(
|
||||
samples,
|
||||
filters,
|
||||
worker::N_FFT,
|
||||
worker::HOP_LENGTH,
|
||||
worker::N_MELS,
|
||||
worker::m::N_FFT,
|
||||
worker::m::HOP_LENGTH,
|
||||
worker::m::N_MELS,
|
||||
false,
|
||||
);
|
||||
Ok(mel)
|
||||
|
@ -9,15 +9,28 @@ pub struct Decoder {
|
||||
#[wasm_bindgen]
|
||||
impl Decoder {
|
||||
#[wasm_bindgen(constructor)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
weights: Vec<u8>,
|
||||
tokenizer: Vec<u8>,
|
||||
mel_filters: Vec<u8>,
|
||||
config: Vec<u8>,
|
||||
quantized: bool,
|
||||
is_multilingual: bool,
|
||||
timestamps: bool,
|
||||
task: Option<String>,
|
||||
language: Option<String>,
|
||||
) -> Result<Decoder, JsError> {
|
||||
let decoder = D::load(ModelData {
|
||||
tokenizer,
|
||||
mel_filters,
|
||||
config,
|
||||
quantized,
|
||||
weights,
|
||||
is_multilingual,
|
||||
timestamps,
|
||||
task,
|
||||
language,
|
||||
});
|
||||
|
||||
match decoder {
|
||||
@ -32,7 +45,6 @@ impl Decoder {
|
||||
.decoder
|
||||
.convert_and_run(&wav_input)
|
||||
.map_err(|e| JsError::new(&e.to_string()))?;
|
||||
|
||||
let json = serde_json::to_string(&segments)?;
|
||||
Ok(json)
|
||||
}
|
||||
|
101
candle-wasm-examples/whisper/src/languages.rs
Normal file
101
candle-wasm-examples/whisper/src/languages.rs
Normal file
@ -0,0 +1,101 @@
|
||||
pub const LANGUAGES: [(&str, &str); 99] = [
|
||||
("en", "english"),
|
||||
("zh", "chinese"),
|
||||
("de", "german"),
|
||||
("es", "spanish"),
|
||||
("ru", "russian"),
|
||||
("ko", "korean"),
|
||||
("fr", "french"),
|
||||
("ja", "japanese"),
|
||||
("pt", "portuguese"),
|
||||
("tr", "turkish"),
|
||||
("pl", "polish"),
|
||||
("ca", "catalan"),
|
||||
("nl", "dutch"),
|
||||
("ar", "arabic"),
|
||||
("sv", "swedish"),
|
||||
("it", "italian"),
|
||||
("id", "indonesian"),
|
||||
("hi", "hindi"),
|
||||
("fi", "finnish"),
|
||||
("vi", "vietnamese"),
|
||||
("he", "hebrew"),
|
||||
("uk", "ukrainian"),
|
||||
("el", "greek"),
|
||||
("ms", "malay"),
|
||||
("cs", "czech"),
|
||||
("ro", "romanian"),
|
||||
("da", "danish"),
|
||||
("hu", "hungarian"),
|
||||
("ta", "tamil"),
|
||||
("no", "norwegian"),
|
||||
("th", "thai"),
|
||||
("ur", "urdu"),
|
||||
("hr", "croatian"),
|
||||
("bg", "bulgarian"),
|
||||
("lt", "lithuanian"),
|
||||
("la", "latin"),
|
||||
("mi", "maori"),
|
||||
("ml", "malayalam"),
|
||||
("cy", "welsh"),
|
||||
("sk", "slovak"),
|
||||
("te", "telugu"),
|
||||
("fa", "persian"),
|
||||
("lv", "latvian"),
|
||||
("bn", "bengali"),
|
||||
("sr", "serbian"),
|
||||
("az", "azerbaijani"),
|
||||
("sl", "slovenian"),
|
||||
("kn", "kannada"),
|
||||
("et", "estonian"),
|
||||
("mk", "macedonian"),
|
||||
("br", "breton"),
|
||||
("eu", "basque"),
|
||||
("is", "icelandic"),
|
||||
("hy", "armenian"),
|
||||
("ne", "nepali"),
|
||||
("mn", "mongolian"),
|
||||
("bs", "bosnian"),
|
||||
("kk", "kazakh"),
|
||||
("sq", "albanian"),
|
||||
("sw", "swahili"),
|
||||
("gl", "galician"),
|
||||
("mr", "marathi"),
|
||||
("pa", "punjabi"),
|
||||
("si", "sinhala"),
|
||||
("km", "khmer"),
|
||||
("sn", "shona"),
|
||||
("yo", "yoruba"),
|
||||
("so", "somali"),
|
||||
("af", "afrikaans"),
|
||||
("oc", "occitan"),
|
||||
("ka", "georgian"),
|
||||
("be", "belarusian"),
|
||||
("tg", "tajik"),
|
||||
("sd", "sindhi"),
|
||||
("gu", "gujarati"),
|
||||
("am", "amharic"),
|
||||
("yi", "yiddish"),
|
||||
("lo", "lao"),
|
||||
("uz", "uzbek"),
|
||||
("fo", "faroese"),
|
||||
("ht", "haitian creole"),
|
||||
("ps", "pashto"),
|
||||
("tk", "turkmen"),
|
||||
("nn", "nynorsk"),
|
||||
("mt", "maltese"),
|
||||
("sa", "sanskrit"),
|
||||
("lb", "luxembourgish"),
|
||||
("my", "myanmar"),
|
||||
("bo", "tibetan"),
|
||||
("tl", "tagalog"),
|
||||
("mg", "malagasy"),
|
||||
("as", "assamese"),
|
||||
("tt", "tatar"),
|
||||
("haw", "hawaiian"),
|
||||
("ln", "lingala"),
|
||||
("ha", "hausa"),
|
||||
("ba", "bashkir"),
|
||||
("jw", "javanese"),
|
||||
("su", "sundanese"),
|
||||
];
|
@ -4,14 +4,14 @@ struct Timer {
|
||||
label: &'static str,
|
||||
}
|
||||
|
||||
impl Timer {
|
||||
fn new(label: &'static str) -> Self {
|
||||
if WITH_TIMER {
|
||||
web_sys::console::time_with_label(label);
|
||||
}
|
||||
Self { label }
|
||||
}
|
||||
}
|
||||
// impl Timer {
|
||||
// fn new(label: &'static str) -> Self {
|
||||
// if WITH_TIMER {
|
||||
// web_sys::console::time_with_label(label);
|
||||
// }
|
||||
// Self { label }
|
||||
// }
|
||||
// }
|
||||
|
||||
impl Drop for Timer {
|
||||
fn drop(&mut self) {
|
||||
@ -23,7 +23,7 @@ impl Drop for Timer {
|
||||
|
||||
mod app;
|
||||
mod audio;
|
||||
mod model;
|
||||
pub mod languages;
|
||||
pub mod worker;
|
||||
pub use app::App;
|
||||
pub use worker::Worker;
|
||||
|
@ -1,417 +0,0 @@
|
||||
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
|
||||
// back when using RUST_LIB_BACKTRACE=1.
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
// The names in comments correspond to the original implementation:
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub num_mel_bins: usize, // n_mels
|
||||
pub max_source_positions: usize, // n_audio_ctx
|
||||
pub d_model: usize, // n_audio_state
|
||||
pub encoder_attention_heads: usize, // n_audio_head
|
||||
pub encoder_layers: usize, // n_audio_layer
|
||||
pub vocab_size: usize, // n_vocab
|
||||
pub max_target_positions: usize, // n_text_ctx
|
||||
// pub n_text_state: usize,
|
||||
pub decoder_attention_heads: usize, // n_text_head
|
||||
pub decoder_layers: usize, // n_text_layer
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn tiny_en() -> Self {
|
||||
Self {
|
||||
num_mel_bins: 80,
|
||||
vocab_size: 51864,
|
||||
max_source_positions: 1500,
|
||||
d_model: 384,
|
||||
encoder_attention_heads: 6,
|
||||
encoder_layers: 4,
|
||||
max_target_positions: 448,
|
||||
// n_text_state: 384,
|
||||
decoder_attention_heads: 6,
|
||||
decoder_layers: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm
|
||||
// specific monitoring.
|
||||
#[derive(Debug)]
|
||||
struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
Self { weight, bias }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let _timer = crate::Timer::new("Linear::forward");
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
};
|
||||
let x = {
|
||||
let _timer = crate::Timer::new("Linear::matmul");
|
||||
x.matmul(&w)?
|
||||
};
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = vb.get(size2, "bias")?;
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
}
|
||||
|
||||
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
Ok(Linear::new(weight, None))
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
|
||||
let bias = vb.get(out_channels, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?;
|
||||
let bias = vb.get(size, "bias")?;
|
||||
Ok(LayerNorm::new(weight, bias, 1e-5))
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
|
||||
struct MultiHeadAttention {
|
||||
query: Linear,
|
||||
key: Linear,
|
||||
value: Linear,
|
||||
out: Linear,
|
||||
n_head: usize,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl MultiHeadAttention {
|
||||
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
|
||||
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
|
||||
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
|
||||
let out = linear(n_state, n_state, vb.pp("out_proj"))?;
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
n_head,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("MultiHeadAttention::forward");
|
||||
let q = self.query.forward(x)?;
|
||||
let (k, v) = match xa {
|
||||
None => {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
(k, v)
|
||||
}
|
||||
Some(x) => {
|
||||
if flush_cache {
|
||||
self.kv_cache = None;
|
||||
}
|
||||
if let Some((k, v)) = &self.kv_cache {
|
||||
(k.clone(), v.clone())
|
||||
} else {
|
||||
let k = self.key.forward(x)?;
|
||||
let v = self.value.forward(x)?;
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
(k, v)
|
||||
}
|
||||
}
|
||||
};
|
||||
let wv = self.qkv_attention(&q, &k, &v, mask)?;
|
||||
let out = self.out.forward(&wv)?;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
||||
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
||||
}
|
||||
|
||||
fn qkv_attention(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let (_, n_ctx, n_state) = q.dims3()?;
|
||||
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||
let q = (self.reshape_head(q)? * scale)?;
|
||||
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||
let v = self.reshape_head(v)?.contiguous()?;
|
||||
let mut qk = {
|
||||
let _timer = crate::Timer::new("qk::matmul");
|
||||
q.matmul(&k)?
|
||||
};
|
||||
if let Some(mask) = mask {
|
||||
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
|
||||
qk = qk.broadcast_add(&mask)?
|
||||
}
|
||||
let w = {
|
||||
let _timer = crate::Timer::new("qk::softmax");
|
||||
candle_nn::ops::softmax(&qk, candle::D::Minus1)?
|
||||
};
|
||||
let wv = {
|
||||
let _timer = crate::Timer::new("wv::matmul");
|
||||
w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?
|
||||
};
|
||||
Ok(wv)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
|
||||
struct ResidualAttentionBlock {
|
||||
attn: MultiHeadAttention,
|
||||
attn_ln: LayerNorm,
|
||||
cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
|
||||
mlp_linear1: Linear,
|
||||
mlp_linear2: Linear,
|
||||
mlp_ln: LayerNorm,
|
||||
}
|
||||
|
||||
impl ResidualAttentionBlock {
|
||||
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
|
||||
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
|
||||
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
|
||||
let cross_attn = if ca {
|
||||
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
|
||||
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
|
||||
Some((cross_attn, cross_attn_ln))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let n_mlp = n_state * 4;
|
||||
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
|
||||
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
|
||||
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
|
||||
Ok(Self {
|
||||
attn,
|
||||
attn_ln,
|
||||
cross_attn,
|
||||
mlp_linear1,
|
||||
mlp_linear2,
|
||||
mlp_ln,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
flush_kv_cache: bool,
|
||||
) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
|
||||
let attn = self
|
||||
.attn
|
||||
.forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
|
||||
let mut x = (x + attn)?;
|
||||
if let Some((attn, ln)) = &mut self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
|
||||
}
|
||||
let mlp = self.mlp_linear2.forward(
|
||||
&self
|
||||
.mlp_linear1
|
||||
.forward(&self.mlp_ln.forward(&x)?)?
|
||||
.gelu()?,
|
||||
)?;
|
||||
Ok((x + mlp)?)
|
||||
}
|
||||
}
|
||||
|
||||
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||
let max_timescale = 10000f32;
|
||||
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect();
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.unsqueeze(1)?;
|
||||
let sh = (length, channels / 2);
|
||||
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
|
||||
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
|
||||
Ok(sincos)
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
|
||||
pub struct AudioEncoder {
|
||||
conv1: Conv1d,
|
||||
conv2: Conv1d,
|
||||
positional_embedding: Tensor,
|
||||
blocks: Vec<ResidualAttentionBlock>,
|
||||
ln_post: LayerNorm,
|
||||
}
|
||||
|
||||
impl AudioEncoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let n_state = cfg.d_model;
|
||||
let n_head = cfg.encoder_attention_heads;
|
||||
let n_ctx = cfg.max_source_positions;
|
||||
let cfg1 = Conv1dConfig {
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
};
|
||||
let cfg2 = Conv1dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
};
|
||||
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
|
||||
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
|
||||
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
|
||||
let blocks = (0..cfg.encoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
conv1,
|
||||
conv2,
|
||||
positional_embedding,
|
||||
blocks,
|
||||
ln_post,
|
||||
})
|
||||
}
|
||||
pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let _timer = crate::Timer::new("AudioEncoder::forward");
|
||||
let x = {
|
||||
let _timer = crate::Timer::new("conv1::forward");
|
||||
self.conv1.forward(x)?.gelu()?
|
||||
};
|
||||
let x = {
|
||||
let _timer = crate::Timer::new("conv2::forward");
|
||||
self.conv2.forward(&x)?.gelu()?
|
||||
};
|
||||
let x = x.transpose(1, 2)?;
|
||||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, None, None, flush_kv_cache)?
|
||||
}
|
||||
let x = self.ln_post.forward(&x)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
|
||||
pub struct TextDecoder {
|
||||
token_embedding: Embedding,
|
||||
positional_embedding: Tensor,
|
||||
blocks: Vec<ResidualAttentionBlock>,
|
||||
ln: LayerNorm,
|
||||
mask: Tensor,
|
||||
}
|
||||
|
||||
impl TextDecoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let _timer = crate::Timer::new("TextDecoder::forward");
|
||||
let n_state = cfg.d_model;
|
||||
let n_head = cfg.decoder_attention_heads;
|
||||
let n_ctx = cfg.max_target_positions;
|
||||
let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
|
||||
let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
|
||||
let blocks = (0..cfg.decoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||
let mask: Vec<_> = (0..n_ctx)
|
||||
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||
.collect();
|
||||
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
|
||||
|
||||
Ok(Self {
|
||||
token_embedding,
|
||||
positional_embedding,
|
||||
blocks,
|
||||
ln,
|
||||
mask,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
|
||||
let x_dims = x.dims();
|
||||
let last = x_dims[x_dims.len() - 1];
|
||||
let token_embedding = self.token_embedding.forward(x)?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
|
||||
}
|
||||
let x = self.ln.forward(&x)?;
|
||||
let w = self
|
||||
.token_embedding
|
||||
.embeddings()
|
||||
.broadcast_left(x_dims[0])?;
|
||||
let logits = x.matmul(&w.t()?)?;
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
|
||||
pub struct Whisper {
|
||||
pub encoder: AudioEncoder,
|
||||
pub decoder: TextDecoder,
|
||||
pub config: Config,
|
||||
}
|
||||
|
||||
impl Whisper {
|
||||
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
|
||||
let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
|
||||
let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
config,
|
||||
})
|
||||
}
|
||||
}
|
@ -1,7 +1,8 @@
|
||||
use crate::model::{Config, Whisper};
|
||||
use crate::languages::LANGUAGES;
|
||||
use anyhow::Error as E;
|
||||
use candle::{safetensors::Load, DType, Device, Tensor};
|
||||
use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
pub use candle_transformers::models::whisper::{self as m, Config};
|
||||
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokenizers::Tokenizer;
|
||||
@ -25,38 +26,46 @@ macro_rules! console_log {
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
|
||||
// Audio parameters.
|
||||
pub const SAMPLE_RATE: usize = 16000;
|
||||
pub const N_FFT: usize = 400;
|
||||
pub const N_MELS: usize = 80;
|
||||
pub const HOP_LENGTH: usize = 160;
|
||||
pub const CHUNK_LENGTH: usize = 30;
|
||||
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
||||
/// Whisper model wrapper that dispatches to one of two implementations
/// exposed under the `m` alias.
pub enum Model {
|
||||
/// Wraps `m::model::Whisper`.
Normal(m::model::Whisper),
|
||||
/// Wraps `m::quantized_model::Whisper`.
Quantized(m::quantized_model::Whisper),
|
||||
}
|
||||
|
||||
pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
||||
pub const LOGPROB_THRESHOLD: f64 = -1.0;
|
||||
pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
||||
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||
// Maybe we should use some traits rather than doing the dispatch for all these.
|
||||
impl Model {
|
||||
pub fn config(&self) -> &Config {
|
||||
match self {
|
||||
Self::Normal(m) => &m.config,
|
||||
Self::Quantized(m) => &m.config,
|
||||
}
|
||||
}
|
||||
|
||||
// Tokenizer dependent bits.
|
||||
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||
const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||
const EOT_TOKEN: &str = "<|endoftext|>";
|
||||
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.encoder.forward(x, flush),
|
||||
Self::Quantized(m) => m.encoder.forward(x, flush),
|
||||
}
|
||||
}
|
||||
|
||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||
pub const SUPPRESS_TOKENS: [u32; 91] = [
|
||||
1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
|
||||
366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
|
||||
1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
|
||||
10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
|
||||
19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
|
||||
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
|
||||
];
|
||||
pub fn decoder_forward(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
xa: &Tensor,
|
||||
flush: bool,
|
||||
) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.decoder.forward(x, xa, flush),
|
||||
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Normal(m) => m.decoder.final_linear(x),
|
||||
Self::Quantized(m) => m.decoder.final_linear(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DecodingResult {
|
||||
@ -77,8 +86,13 @@ pub struct Segment {
|
||||
|
||||
#[allow(unused)]
|
||||
pub struct Decoder {
|
||||
model: Whisper,
|
||||
model: Model,
|
||||
rng: rand::rngs::StdRng,
|
||||
task: Option<Task>,
|
||||
language: Option<String>,
|
||||
is_multilingual: bool,
|
||||
mel_filters: Vec<f32>,
|
||||
timestamps: bool,
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Tensor,
|
||||
sot_token: u32,
|
||||
@ -90,32 +104,43 @@ pub struct Decoder {
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Whisper,
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
mel_filters: Vec<f32>,
|
||||
device: &Device,
|
||||
task: Option<Task>,
|
||||
language: Option<String>,
|
||||
is_multilingual: bool,
|
||||
timestamps: bool,
|
||||
) -> anyhow::Result<Self> {
|
||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
|
||||
.map(|i| {
|
||||
if SUPPRESS_TOKENS.contains(&i) {
|
||||
if model.config().suppress_tokens.contains(&i) {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0f32
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
||||
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
|
||||
let seed = 299792458;
|
||||
Ok(Self {
|
||||
model,
|
||||
mel_filters,
|
||||
rng: StdRng::seed_from_u64(seed),
|
||||
tokenizer,
|
||||
mel_filters,
|
||||
task,
|
||||
timestamps,
|
||||
language,
|
||||
is_multilingual,
|
||||
suppress_tokens,
|
||||
sot_token,
|
||||
transcribe_token,
|
||||
@ -126,40 +151,73 @@ impl Decoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
|
||||
let model = &mut self.model;
|
||||
let audio_features = model.encoder.forward(mel, true)?;
|
||||
console_log!("audio features: {:?}", audio_features.dims());
|
||||
let sample_len = model.config.max_target_positions / 2;
|
||||
let language_token = match (self.is_multilingual, &self.language) {
|
||||
(true, None) => Some(detect_language(model, &self.tokenizer, mel)?),
|
||||
(false, None) => None,
|
||||
(true, Some(language)) => {
|
||||
match token_id(&self.tokenizer, &format!("<|{:?}|>", self.language)) {
|
||||
Ok(token_id) => Some(token_id),
|
||||
Err(_) => anyhow::bail!("language {language} is not supported"),
|
||||
}
|
||||
}
|
||||
(false, Some(_)) => {
|
||||
anyhow::bail!("a language cannot be set for non-multilingual models")
|
||||
}
|
||||
};
|
||||
|
||||
let audio_features = model.encoder_forward(mel, true)?;
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
let sample_len = model.config().max_target_positions / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![self.sot_token, self.transcribe_token];
|
||||
let mut tokens = vec![self.sot_token];
|
||||
if let Some(language_token) = language_token {
|
||||
tokens.push(language_token);
|
||||
}
|
||||
match self.task {
|
||||
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
|
||||
Some(Task::Translate) => tokens.push(self.translate_token),
|
||||
}
|
||||
if !self.timestamps {
|
||||
tokens.push(self.no_timestamps_token);
|
||||
}
|
||||
for i in 0..sample_len {
|
||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
||||
|
||||
// The model expects a batch dim but this inference loop does not handle
|
||||
// it so we add it at this point.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
|
||||
|
||||
// Extract the no speech probability on the first iteration by looking at the first
|
||||
// token logits and the probability for the according token.
|
||||
if i == 0 {
|
||||
no_speech_prob = softmax(&logits.get(0)?, 0)?
|
||||
.get(self.no_speech_token as usize)?
|
||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
no_speech_prob = softmax(&logits, 0)?
|
||||
.i(self.no_speech_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
}
|
||||
|
||||
let (seq_len, _) = logits.dims2()?;
|
||||
let logits = logits
|
||||
.get(seq_len - 1)?
|
||||
.broadcast_add(&self.suppress_tokens)?;
|
||||
let (_, seq_len, _) = ys.dims3()?;
|
||||
let logits = model
|
||||
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
|
||||
.i(0)?
|
||||
.i(0)?;
|
||||
// TODO: Besides suppress tokens, we should apply the heuristics from
|
||||
// ApplyTimestampRules, i.e.:
|
||||
// - Timestamps come in pairs, except before EOT.
|
||||
// - Timestamps should be non-decreasing.
|
||||
// - If the sum of the probabilities of timestamps is higher than any other tokens,
|
||||
// only consider timestamps when sampling.
|
||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
|
||||
let logits = logits.broadcast_add(&self.suppress_tokens)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = softmax(&(&logits / t)?, 0)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(rng) as u32
|
||||
distr.sample(&mut self.rng) as u32
|
||||
} else {
|
||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||
logits_v
|
||||
@ -171,9 +229,9 @@ impl Decoder {
|
||||
};
|
||||
tokens.push(next_token);
|
||||
let prob = softmax(&logits, candle::D::Minus1)?
|
||||
.get(next_token as usize)?
|
||||
.i(next_token as usize)?
|
||||
.to_scalar::<f32>()? as f64;
|
||||
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
|
||||
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
|
||||
break;
|
||||
}
|
||||
sum_logprob += prob.ln();
|
||||
@ -191,22 +249,18 @@ impl Decoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_with_fallback(
|
||||
&mut self,
|
||||
segment: &Tensor,
|
||||
rng: &mut StdRng,
|
||||
) -> anyhow::Result<DecodingResult> {
|
||||
for (i, &t) in TEMPERATURES.iter().enumerate() {
|
||||
let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
|
||||
if i == TEMPERATURES.len() - 1 {
|
||||
fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
|
||||
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
|
||||
let dr: Result<DecodingResult, _> = self.decode(segment, t);
|
||||
if i == m::TEMPERATURES.len() - 1 {
|
||||
return dr;
|
||||
}
|
||||
// On errors, we try again with a different temperature.
|
||||
match dr {
|
||||
Ok(dr) => {
|
||||
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|
||||
|| dr.avg_logprob < LOGPROB_THRESHOLD;
|
||||
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
|
||||
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|
||||
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
|
||||
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
|
||||
return Ok(dr);
|
||||
}
|
||||
}
|
||||
@ -219,18 +273,17 @@ impl Decoder {
|
||||
}
|
||||
|
||||
fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||
let mut rng = StdRng::seed_from_u64(299792458);
|
||||
let (_, _, content_frames) = mel.dims3()?;
|
||||
let mut seek = 0;
|
||||
let mut segments = vec![];
|
||||
while seek < content_frames {
|
||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
|
||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
|
||||
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
||||
let dr = self.decode_with_fallback(&mel_segment)?;
|
||||
seek += segment_size;
|
||||
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
||||
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
|
||||
console_log!("no speech detected, skipping {seek} {dr:?}");
|
||||
continue;
|
||||
}
|
||||
@ -247,17 +300,39 @@ impl Decoder {
|
||||
|
||||
pub fn load(md: ModelData) -> anyhow::Result<Self> {
|
||||
let device = Device::Cpu;
|
||||
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
|
||||
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(E::msg)?;
|
||||
|
||||
let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
|
||||
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
||||
console_log!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?;
|
||||
let config = Config::tiny_en();
|
||||
let whisper = Whisper::load(&vb, config)?;
|
||||
let config: Config = serde_json::from_slice(&md.config)?;
|
||||
let model = if md.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
|
||||
&md.weights,
|
||||
)?;
|
||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
|
||||
} else {
|
||||
let vb = VarBuilder::from_buffered_safetensors(md.weights, m::DTYPE, &device)?;
|
||||
Model::Normal(m::model::Whisper::load(&vb, config)?)
|
||||
};
|
||||
console_log!("done loading model");
|
||||
let decoder = Self::new(whisper, tokenizer, mel_filters, &device)?;
|
||||
|
||||
let task = match md.task.as_deref() {
|
||||
Some("translate") => Some(Task::Translate),
|
||||
_ => Some(Task::Transcribe),
|
||||
};
|
||||
|
||||
let decoder = Self::new(
|
||||
model,
|
||||
tokenizer,
|
||||
mel_filters,
|
||||
&device,
|
||||
task,
|
||||
md.language,
|
||||
md.is_multilingual,
|
||||
md.timestamps,
|
||||
)?;
|
||||
Ok(decoder)
|
||||
}
|
||||
|
||||
@ -266,8 +341,8 @@ impl Decoder {
|
||||
let mut wav_input = std::io::Cursor::new(wav_input);
|
||||
let (header, data) = wav::read(&mut wav_input)?;
|
||||
console_log!("loaded wav data: {header:?}");
|
||||
if header.sampling_rate != SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("wav file must have a {SAMPLE_RATE} sampling rate");
|
||||
if header.sampling_rate != m::SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
|
||||
}
|
||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
||||
@ -277,27 +352,74 @@ impl Decoder {
|
||||
console_log!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?;
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
|
||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||
console_log!("loaded mel: {:?}", mel.dims());
|
||||
let segments = self.run(&mel)?;
|
||||
Ok(segments)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the token id for the selected language.
|
||||
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32, E> {
|
||||
console_log!("detecting language");
|
||||
let (_bsize, _, seq_len) = mel.dims3()?;
|
||||
let mel = mel.narrow(
|
||||
2,
|
||||
0,
|
||||
usize::min(seq_len, model.config().max_source_positions),
|
||||
)?;
|
||||
let device = mel.device();
|
||||
|
||||
let language_token_ids = LANGUAGES
|
||||
.iter()
|
||||
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
|
||||
.map(|e| e.map_err(E::msg))
|
||||
.collect::<Result<Vec<_>, E>>()?;
|
||||
|
||||
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
|
||||
let audio_features = model.encoder_forward(&mel, true)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
|
||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
||||
let logits = logits.index_select(&language_token_ids, 0)?;
|
||||
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||
let probs = probs.to_vec1::<f32>()?;
|
||||
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
|
||||
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for ((_, language), p) in probs.iter().take(5) {
|
||||
println!("{language}: {p}")
|
||||
}
|
||||
let token = &format!("<|{}|>", probs[0].0 .0);
|
||||
let language = token_id(tokenizer, token)?;
|
||||
console_log!("detected language: {language} {token}");
|
||||
Ok(language)
|
||||
}
|
||||
/// Looks up the id of `token` in `tokenizer`.
///
/// Returns an error with the message `no token-id for {token}` when the
/// tokenizer does not know the token.
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
    if let Some(id) = tokenizer.token_to_id(token) {
        return Ok(id);
    }
    candle::bail!("no token-id for {token}")
}
|
||||
/// Task requested from the decoder; it selects which task token is pushed
/// onto the initial token sequence in `Decoder::decode`.
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
|
||||
pub enum Task {
|
||||
/// Pushes the transcribe token; also used when no task is set.
Transcribe,
|
||||
/// Pushes the translate token.
Translate,
|
||||
}
|
||||
|
||||
// Communication to the worker happens through bincode, the model weights and configs are fetched
|
||||
// on the main thread and transferred via the following structure.
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ModelData {
|
||||
pub weights: Vec<u8>,
|
||||
pub tokenizer: Vec<u8>,
|
||||
pub mel_filters: Vec<u8>,
|
||||
pub weights: Vec<u8>,
|
||||
pub config: Vec<u8>,
|
||||
pub quantized: bool,
|
||||
pub timestamps: bool,
|
||||
pub is_multilingual: bool,
|
||||
pub language: Option<String>,
|
||||
pub task: Option<String>,
|
||||
}
|
||||
|
||||
pub struct Worker {
|
||||
|
@ -17,23 +17,46 @@ class Whisper {
|
||||
static instance = {};
|
||||
// Retrieve the Whisper model. When called for the first time,
|
||||
// this will load the model and save it for future use.
|
||||
static async getInstance(weightsURL, modelID, tokenizerURL, mel_filtersURL) {
|
||||
static async getInstance(params) {
|
||||
const {
|
||||
weightsURL,
|
||||
modelID,
|
||||
tokenizerURL,
|
||||
mel_filtersURL,
|
||||
configURL,
|
||||
quantized,
|
||||
is_multilingual,
|
||||
timestamps,
|
||||
task,
|
||||
language,
|
||||
} = params;
|
||||
// load individual modelID only once
|
||||
if (!this.instance[modelID]) {
|
||||
await init();
|
||||
|
||||
self.postMessage({ status: "loading", message: "Loading Model" });
|
||||
const [weightsArrayU8, tokenizerArrayU8, mel_filtersArrayU8] =
|
||||
await Promise.all([
|
||||
fetchArrayBuffer(weightsURL),
|
||||
fetchArrayBuffer(tokenizerURL),
|
||||
fetchArrayBuffer(mel_filtersURL),
|
||||
]);
|
||||
const [
|
||||
weightsArrayU8,
|
||||
tokenizerArrayU8,
|
||||
mel_filtersArrayU8,
|
||||
configArrayU8,
|
||||
] = await Promise.all([
|
||||
fetchArrayBuffer(weightsURL),
|
||||
fetchArrayBuffer(tokenizerURL),
|
||||
fetchArrayBuffer(mel_filtersURL),
|
||||
fetchArrayBuffer(configURL),
|
||||
]);
|
||||
|
||||
this.instance[modelID] = new Decoder(
|
||||
weightsArrayU8,
|
||||
tokenizerArrayU8,
|
||||
mel_filtersArrayU8
|
||||
mel_filtersArrayU8,
|
||||
configArrayU8,
|
||||
quantized,
|
||||
is_multilingual,
|
||||
timestamps,
|
||||
task,
|
||||
language
|
||||
);
|
||||
} else {
|
||||
self.postMessage({ status: "loading", message: "Model Already Loaded" });
|
||||
@ -43,17 +66,37 @@ class Whisper {
|
||||
}
|
||||
|
||||
self.addEventListener("message", async (event) => {
|
||||
const { weightsURL, modelID, tokenizerURL, mel_filtersURL, audioURL } =
|
||||
event.data;
|
||||
const {
|
||||
weightsURL,
|
||||
modelID,
|
||||
tokenizerURL,
|
||||
configURL,
|
||||
mel_filtersURL,
|
||||
audioURL,
|
||||
} = event.data;
|
||||
try {
|
||||
self.postMessage({ status: "decoding", message: "Starting Decoder" });
|
||||
|
||||
const decoder = await Whisper.getInstance(
|
||||
let quantized = false;
|
||||
if (modelID.includes("quantized")) {
|
||||
quantized = true;
|
||||
}
|
||||
let is_multilingual = false;
|
||||
if (modelID.includes("multilingual")) {
|
||||
is_multilingual = true;
|
||||
}
|
||||
let timestamps = true;
|
||||
const decoder = await Whisper.getInstance({
|
||||
weightsURL,
|
||||
modelID,
|
||||
tokenizerURL,
|
||||
mel_filtersURL
|
||||
);
|
||||
mel_filtersURL,
|
||||
configURL,
|
||||
quantized,
|
||||
is_multilingual,
|
||||
timestamps,
|
||||
task: null,
|
||||
language: null,
|
||||
});
|
||||
|
||||
self.postMessage({ status: "decoding", message: "Loading Audio" });
|
||||
const audioArrayU8 = await fetchArrayBuffer(audioURL);
|
||||
|
Reference in New Issue
Block a user