Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Whisper quantized wasm (#1028)
* [Whisper] Update to use quantized model
* [whisper] add language detection
* [whisper] change assets location
* [whisper] adapt js example with quantized models
* [whisper] better task parsing
* [whisper] minor fixes
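For orientation before the diff: the central change is that the worker can now hold either a float Whisper model loaded from safetensors or a quantized one loaded from a GGUF file, behind a single `Model` enum (see `src/worker.rs` below). The following is a minimal sketch of that loading split, assuming the candle 0.3 APIs (`VarBuilder::from_buffered_safetensors` and `quantized_var_builder::VarBuilder::from_gguf_buffer`); exact constructor names and signatures may differ between candle versions, so treat this as an illustration rather than the commit's own code:

```rust
use candle::Device;
use candle_nn::VarBuilder;
use candle_transformers::models::whisper as m;

pub enum Model {
    Normal(m::model::Whisper),
    Quantized(m::quantized_model::Whisper),
}

// Sketch: pick the loader based on the `quantized` flag carried in ModelData.
// `weights` is the raw file content fetched by the web app; `config` has
// already been parsed from the model's config.json.
fn load_model(weights: Vec<u8>, config: m::Config, quantized: bool) -> anyhow::Result<Model> {
    let device = Device::Cpu; // wasm targets run on the CPU backend
    let model = if quantized {
        // Assumed API for reading a .gguf buffer into quantized weights.
        let vb =
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
        Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
    } else {
        let vb = VarBuilder::from_buffered_safetensors(weights, m::DTYPE, &device)?;
        Model::Normal(m::model::Whisper::load(&vb, config)?)
    };
    Ok(model)
}
```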
.gitignore (vendored, 7 lines changed)
@@ -29,9 +29,10 @@ trace-*.json
 candle-wasm-examples/*/build
 candle-wasm-examples/*/*.bin
 candle-wasm-examples/*/*.jpeg
-candle-wasm-examples/*/*.wav
-candle-wasm-examples/*/*.safetensors
+candle-wasm-examples/*/audios/*.wav
+candle-wasm-examples/**/*.safetensors
+candle-wasm-examples/**/*.gguf
 candle-wasm-examples/*/package-lock.json
+candle-wasm-examples/**/config*.json
 .DS_Store
 .idea/*
candle-wasm-examples/whisper/Cargo.toml
@@ -11,6 +11,7 @@ license.workspace = true
 [dependencies]
 candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
 candle-nn = { path = "../../candle-nn", version = "0.3.0" }
+candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
candle-wasm-examples/whisper/README.md
@@ -10,19 +10,31 @@ From the `candle-wasm-examples/whisper` directory run:

 Download assets:

 ```bash
-# Model and tokenizer
+# mel filters
 wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors
-wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tiny.en.safetensors
-wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tokenizer.en.json
+# Model and tokenizer tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/resolve/main/model.safetensors -P whisper-tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/tokenizer.json -P whisper-tiny.en
+wget -c https://huggingface.co/openai/whisper-tiny.en/raw/main/config.json -P whisper-tiny.en
+# model and tokenizer tiny multilanguage
+wget -c https://huggingface.co/openai/whisper-tiny/resolve/main/model.safetensors -P whisper-tiny
+wget -c https://huggingface.co/openai/whisper-tiny/raw/main/tokenizer.json -P whisper-tiny
+wget -c https://huggingface.co/openai/whisper-tiny/raw/main/config.json -P whisper-tiny
+
+#quantized
+wget -c https://huggingface.co/lmz/candle-whisper/resolve/main/model-tiny-en-q80.gguf -P quantized
+wget -c https://huggingface.co/lmz/candle-whisper/raw/main/tokenizer-tiny-en.json -P quantized
+wget -c https://huggingface.co/lmz/candle-whisper/raw/main/config-tiny-en.json -P quantized
+
 # Audio samples
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -O gb0.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -O a13.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -O gb1.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -O hp0.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -O jfk.wav
-wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -O mm0.wav
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -P audios
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -P audios
 ```
candle-wasm-examples/whisper/index.html
@@ -3,22 +3,38 @@
   <head>
     <meta charset="utf-8" />
     <title>Welcome to Candle!</title>
-    <link data-trunk rel="copy-file" href="jfk.wav" />
-    <link data-trunk rel="copy-file" href="mm0.wav" />
-    <link data-trunk rel="copy-file" href="a13.wav" />
-    <link data-trunk rel="copy-file" href="gb0.wav" />
-    <link data-trunk rel="copy-file" href="gb1.wav" />
-    <link data-trunk rel="copy-file" href="hp0.wav" />
-    <link data-trunk rel="copy-file" href="tokenizer.en.json" />
     <link data-trunk rel="copy-file" href="mel_filters.safetensors" />
-    <link data-trunk rel="copy-file" href="tiny.en.safetensors" />
-    <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
-    <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
+    <!-- samples -->
+    <link data-trunk rel="copy-dir" href="audios" />
+    <!-- tiny.en -->
+    <link data-trunk rel="copy-dir" href="whisper-tiny.en" />
+    <!-- tiny -->
+    <link data-trunk rel="copy-dir" href="whisper-tiny" />
+    <!-- quantized -->
+    <link data-trunk rel="copy-dir" href="quantized" />

+    <link
+      data-trunk
+      rel="rust"
+      href="Cargo.toml"
+      data-bin="app"
+      data-type="main" />
+    <link
+      data-trunk
+      rel="rust"
+      href="Cargo.toml"
+      data-bin="worker"
+      data-type="worker" />

-    <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
-    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
-    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
+    <link
+      rel="stylesheet"
+      href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic" />
+    <link
+      rel="stylesheet"
+      href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css" />
+    <link
+      rel="stylesheet"
+      href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css" />
   </head>
   <body></body>
 </html>
candle-wasm-examples/whisper/lib-example.html
@@ -26,9 +26,30 @@

       // models base url
       const MODELS = {
+        tiny_multilingual: {
+          base_url: "https://huggingface.co/openai/whisper-tiny/resolve/main/",
+          model: "model.safetensors",
+          tokenizer: "tokenizer.json",
+          config: "config.json",
+        },
         tiny_en: {
           base_url:
-            "https://huggingface.co/openai/whisper-tiny.en/resolve/refs%2Fpr%2F17/",
+            "https://huggingface.co/openai/whisper-tiny.en/resolve/main/",
+          model: "model.safetensors",
+          tokenizer: "tokenizer.json",
+          config: "config.json",
+        },
+        tiny_quantized_multilingual_q80: {
+          base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
+          model: "model-tiny-q80.gguf",
+          tokenizer: "tokenizer-tiny.json",
+          config: "config-tiny.json",
+        },
+        tiny_en_quantized_q80: {
+          base_url: "https://huggingface.co/lmz/candle-whisper/resolve/main/",
+          model: "model-tiny-q80.gguf",
+          tokenizer: "tokenizer-tiny-en.json",
+          config: "config-tiny-en.json",
         },
       };
       const whisperWorker = new Worker("./whisperWorker.js", {
@@ -39,6 +60,7 @@
         weightsURL, // URL to the weights file
         modelID, // model ID
         tokenizerURL, // URL to the tokenizer file
+        configURL, // model config URL
         mel_filtersURL, // URL to the mel filters file
         audioURL, // URL to the audio file
         updateStatus // function to update the status
@@ -48,6 +70,7 @@
           weightsURL,
           modelID,
           tokenizerURL,
+          configURL,
           mel_filtersURL,
           audioURL,
         });
@@ -128,13 +151,16 @@
           return;
         }
         const modelID = document.querySelector("#model").value;
-        const modelURL = MODELS[modelID].base_url + "model.safetensors";
-        const tokenizerURL = MODELS[modelID].base_url + "tokenizer.json";
+        const model = MODELS[modelID];
+        const modelURL = model.base_url + model.model;
+        const tokenizerURL = model.base_url + model.tokenizer;
+        const configURL = model.base_url + model.config;

         classifyAudio(
           modelURL,
           modelID,
           tokenizerURL,
+          configURL,
           "mel_filters.safetensors",
           audioURL,
           updateStatus
@@ -178,8 +204,7 @@
         <a
           href="https://huggingface.co/openai/"
           target="_blank"
-          class="underline hover:text-blue-500 hover:no-underline"
-        >
+          class="underline hover:text-blue-500 hover:no-underline">
           OpenAI Whisper models
         </a>
         and WASM runtime built with
@@ -196,37 +221,38 @@
         <label for="model" class="font-medium">Models Options: </label>
         <select
           id="model"
-          class="border-2 border-gray-500 rounded-md font-light"
-        >
+          class="border-2 border-gray-500 rounded-md font-light">
+          <option value="tiny_multilingual" selected>tiny (151 MB)</option>
           <option value="tiny_en" selected>tiny.en (151 MB)</option>
+          <option value="tiny_quantized_multilingual_q80">
+            tiny quantized q80 (41.5 MB)
+          </option>
+          <option value="tiny_en_quantized_q80">
+            tiny.en quantized q80 (41.8 MB)
+          </option>
         </select>
       </div>
       <!-- drag and drop area -->
       <div class="relative">
         <div
           id="drop-area"
-          class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden"
-        >
+          class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden">
           <div
-            class="flex flex-col items-center justify-center space-y-1 text-center"
-          >
+            class="flex flex-col items-center justify-center space-y-1 text-center">
             <svg
               width="25"
               height="25"
               viewBox="0 0 25 25"
               fill="none"
-              xmlns="http://www.w3.org/2000/svg"
-            >
+              xmlns="http://www.w3.org/2000/svg">
               <path
                 d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
-                fill="#000"
-              />
+                fill="#000" />
             </svg>
             <div class="flex text-sm text-gray-600">
               <label
                 for="file-upload"
-                class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
-              >
+                class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700">
                 <span>Drag and drop your audio here</span>
                 <span class="block text-xs">or</span>
                 <span class="block text-xs">Click to upload</span>
@@ -237,15 +263,13 @@
                 name="file-upload"
                 type="file"
                 accept="audio/*"
-                class="sr-only"
-              />
+                class="sr-only" />
             </div>
             <audio
               id="audio"
               hidden
               controls
-              class="w-full p-2 select-none"
-            ></audio>
+              class="w-full p-2 select-none"></audio>
           </div>
         </div>
       </div>
       <div>
@@ -253,43 +277,37 @@
         <h3 class="font-medium">Examples:</h3>
         <button
           data-value="samples_jfk.wav"
-          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
-        >
+          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
           <span>jfk.wav</span>
           <span class="text-xs block"> (352 kB)</span>
         </button>
         <button
           data-value="samples_a13.wav"
-          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
-        >
+          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
           <span>a13.wav</span>
           <span class="text-xs block"> (960 kB)</span>
         </button>
         <button
           data-value="samples_mm0.wav"
-          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
-        >
+          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
           <span>mm0.wav</span>
           <span class="text-xs block new"> (957 kB)</span>
         </button>
         <button
           data-value="samples_gb0.wav"
-          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
-        >
+          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
           <span>gb0.wav </span>
           <span class="text-xs block">(4.08 MB)</span>
         </button>
         <button
           data-value="samples_gb1.wav"
-          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
-        >
+          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
           <span>gb1.wav </span>
           <span class="text-xs block">(6.36 MB)</span>
         </button>
         <button
           data-value="samples_hp0.wav"
-          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
-        >
+          class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline">
           <span>hp0.wav </span>
           <span class="text-xs block">(8.75 MB)</span>
         </button>
@@ -300,16 +318,14 @@
         <button
           id="detect"
           disabled
-          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
-        >
+          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 px-4 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
           Transcribe Audio
         </button>
       </div>
       <div>
         <h3 class="font-medium">Transcription:</h3>
         <div
-          class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
-        >
+          class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
           <p hidden id="output-generation" class="grid-rows-2"></p>
           <span id="output-status" class="m-auto font-light"
             >No transcription results yet</span
candle-wasm-examples/whisper/src/app.rs
@@ -7,7 +7,12 @@ use yew::{html, Component, Context, Html};
 use yew_agent::{Bridge, Bridged};

 const SAMPLE_NAMES: [&str; 6] = [
-    "jfk.wav", "a13.wav", "gb0.wav", "gb1.wav", "hp0.wav", "mm0.wav",
+    "audios/samples_jfk.wav",
+    "audios/samples_a13.wav",
+    "audios/samples_gb0.wav",
+    "audios/samples_gb1.wav",
+    "audios/samples_hp0.wav",
+    "audios/samples_mm0.wav",
 ];

 async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
@@ -54,14 +59,46 @@ pub struct App {
 }

 async fn model_data_load() -> Result<ModelData, JsValue> {
-    let tokenizer = fetch_url("tokenizer.en.json").await?;
-    let mel_filters = fetch_url("mel_filters.safetensors").await?;
-    let weights = fetch_url("tiny.en.safetensors").await?;
+    let quantized = false;
+    let is_multilingual = false;
+
+    let (tokenizer, mel_filters, weights, config) = if quantized {
+        console_log!("loading quantized weights");
+        let tokenizer = fetch_url("quantized/tokenizer-tiny-en.json").await?;
+        let mel_filters = fetch_url("mel_filters.safetensors").await?;
+        let weights = fetch_url("quantized/model-tiny-en-q80.gguf").await?;
+        let config = fetch_url("quantized/config-tiny-en.json").await?;
+        (tokenizer, mel_filters, weights, config)
+    } else {
+        console_log!("loading float weights");
+        if is_multilingual {
+            let mel_filters = fetch_url("mel_filters.safetensors").await?;
+            let tokenizer = fetch_url("whisper-tiny/tokenizer.json").await?;
+            let weights = fetch_url("whisper-tiny/model.safetensors").await?;
+            let config = fetch_url("whisper-tiny/config.json").await?;
+            (tokenizer, mel_filters, weights, config)
+        } else {
+            let mel_filters = fetch_url("mel_filters.safetensors").await?;
+            let tokenizer = fetch_url("whisper-tiny.en/tokenizer.json").await?;
+            let weights = fetch_url("whisper-tiny.en/model.safetensors").await?;
+            let config = fetch_url("whisper-tiny.en/config.json").await?;
+            (tokenizer, mel_filters, weights, config)
+        }
+    };
+
+    let timestamps = true;
+    let _task = Some("transcribe".to_string());
     console_log!("{}", weights.len());
     Ok(ModelData {
         tokenizer,
         mel_filters,
         weights,
+        config,
+        quantized,
+        timestamps,
+        task: None,
+        is_multilingual,
+        language: None,
     })
 }
candle-wasm-examples/whisper/src/audio.rs
@@ -168,7 +168,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
     let n_len = samples.len() / fft_step;

     // pad audio with at least one extra chunk of zeros
-    let pad = 100 * worker::CHUNK_LENGTH / 2;
+    let pad = 100 * worker::m::CHUNK_LENGTH / 2;
     let n_len = if n_len % pad != 0 {
         (n_len / pad + 1) * pad
     } else {
@@ -206,9 +206,9 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
     let mel = log_mel_spectrogram_(
         samples,
         filters,
-        worker::N_FFT,
-        worker::HOP_LENGTH,
-        worker::N_MELS,
+        worker::m::N_FFT,
+        worker::m::HOP_LENGTH,
+        worker::m::N_MELS,
         false,
     );
     Ok(mel)
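A quick worked check of the `pad` constant in the hunk above (all constants are the Whisper defaults, now re-exported from `candle_transformers::models::whisper`): 16 kHz audio with a 160-sample hop yields 100 mel frames per second, so `100 * CHUNK_LENGTH / 2` is half a 30-second chunk, i.e. 1500 frames. A standalone sketch of that arithmetic:

```rust
// Constant values mirror the Whisper defaults used by the crate
// (SAMPLE_RATE = 16_000, HOP_LENGTH = 160, CHUNK_LENGTH = 30).
const SAMPLE_RATE: usize = 16_000;
const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;

fn main() {
    let frames_per_second = SAMPLE_RATE / HOP_LENGTH; // 100 mel frames per second
    let pad = frames_per_second * CHUNK_LENGTH / 2; // 1_500 frames, i.e. 15 s
    println!("frames/s = {frames_per_second}, pad = {pad}");
}
```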
candle-wasm-examples/whisper/src/bin/m.rs
@@ -9,15 +9,28 @@ pub struct Decoder {
 #[wasm_bindgen]
 impl Decoder {
     #[wasm_bindgen(constructor)]
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         weights: Vec<u8>,
         tokenizer: Vec<u8>,
         mel_filters: Vec<u8>,
+        config: Vec<u8>,
+        quantized: bool,
+        is_multilingual: bool,
+        timestamps: bool,
+        task: Option<String>,
+        language: Option<String>,
     ) -> Result<Decoder, JsError> {
         let decoder = D::load(ModelData {
             tokenizer,
             mel_filters,
+            config,
+            quantized,
             weights,
+            is_multilingual,
+            timestamps,
+            task,
+            language,
         });

         match decoder {
@@ -32,7 +45,6 @@ impl Decoder {
             .decoder
             .convert_and_run(&wav_input)
             .map_err(|e| JsError::new(&e.to_string()))?;
-
         let json = serde_json::to_string(&segments)?;
         Ok(json)
     }
candle-wasm-examples/whisper/src/languages.rs (new file, 101 lines)
@@ -0,0 +1,101 @@
+pub const LANGUAGES: [(&str, &str); 99] = [
+    ("en", "english"),
+    ("zh", "chinese"),
+    ("de", "german"),
+    ("es", "spanish"),
+    ("ru", "russian"),
+    ("ko", "korean"),
+    ("fr", "french"),
+    ("ja", "japanese"),
+    ("pt", "portuguese"),
+    ("tr", "turkish"),
+    ("pl", "polish"),
+    ("ca", "catalan"),
+    ("nl", "dutch"),
+    ("ar", "arabic"),
+    ("sv", "swedish"),
+    ("it", "italian"),
+    ("id", "indonesian"),
+    ("hi", "hindi"),
+    ("fi", "finnish"),
+    ("vi", "vietnamese"),
+    ("he", "hebrew"),
+    ("uk", "ukrainian"),
+    ("el", "greek"),
+    ("ms", "malay"),
+    ("cs", "czech"),
+    ("ro", "romanian"),
+    ("da", "danish"),
+    ("hu", "hungarian"),
+    ("ta", "tamil"),
+    ("no", "norwegian"),
+    ("th", "thai"),
+    ("ur", "urdu"),
+    ("hr", "croatian"),
+    ("bg", "bulgarian"),
+    ("lt", "lithuanian"),
+    ("la", "latin"),
+    ("mi", "maori"),
+    ("ml", "malayalam"),
+    ("cy", "welsh"),
+    ("sk", "slovak"),
+    ("te", "telugu"),
+    ("fa", "persian"),
+    ("lv", "latvian"),
+    ("bn", "bengali"),
+    ("sr", "serbian"),
+    ("az", "azerbaijani"),
+    ("sl", "slovenian"),
+    ("kn", "kannada"),
+    ("et", "estonian"),
+    ("mk", "macedonian"),
+    ("br", "breton"),
+    ("eu", "basque"),
+    ("is", "icelandic"),
+    ("hy", "armenian"),
+    ("ne", "nepali"),
+    ("mn", "mongolian"),
+    ("bs", "bosnian"),
+    ("kk", "kazakh"),
+    ("sq", "albanian"),
+    ("sw", "swahili"),
+    ("gl", "galician"),
+    ("mr", "marathi"),
+    ("pa", "punjabi"),
+    ("si", "sinhala"),
+    ("km", "khmer"),
+    ("sn", "shona"),
+    ("yo", "yoruba"),
+    ("so", "somali"),
+    ("af", "afrikaans"),
+    ("oc", "occitan"),
+    ("ka", "georgian"),
+    ("be", "belarusian"),
+    ("tg", "tajik"),
+    ("sd", "sindhi"),
+    ("gu", "gujarati"),
+    ("am", "amharic"),
+    ("yi", "yiddish"),
+    ("lo", "lao"),
+    ("uz", "uzbek"),
+    ("fo", "faroese"),
+    ("ht", "haitian creole"),
+    ("ps", "pashto"),
+    ("tk", "turkmen"),
+    ("nn", "nynorsk"),
+    ("mt", "maltese"),
+    ("sa", "sanskrit"),
+    ("lb", "luxembourgish"),
+    ("my", "myanmar"),
+    ("bo", "tibetan"),
+    ("tl", "tagalog"),
+    ("mg", "malagasy"),
+    ("as", "assamese"),
+    ("tt", "tatar"),
+    ("haw", "hawaiian"),
+    ("ln", "lingala"),
+    ("ha", "hausa"),
+    ("ba", "bashkir"),
+    ("jw", "javanese"),
+    ("su", "sundanese"),
+];
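The `LANGUAGES` table backs the new language detection: each ISO code corresponds to a special tokenizer token of the form `<|code|>` whose logit is inspected during detection. A minimal sketch of a lookup helper built on this table (a hypothetical free function for illustration, not part of this commit; the table is truncated here):

```rust
// Maps an ISO language code from the table to the Whisper special token
// string; returns None for codes the model does not know about.
pub const LANGUAGES: [(&str, &str); 3] = [("en", "english"), ("zh", "chinese"), ("de", "german")];

fn language_token(code: &str) -> Option<String> {
    LANGUAGES
        .iter()
        .find(|(c, _)| *c == code)
        .map(|(c, _)| format!("<|{c}|>"))
}

fn main() {
    assert_eq!(language_token("zh").as_deref(), Some("<|zh|>"));
    assert_eq!(language_token("xx"), None);
}
```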
candle-wasm-examples/whisper/src/lib.rs
@@ -4,14 +4,14 @@ struct Timer {
     label: &'static str,
 }

-impl Timer {
-    fn new(label: &'static str) -> Self {
-        if WITH_TIMER {
-            web_sys::console::time_with_label(label);
-        }
-        Self { label }
-    }
-}
+// impl Timer {
+//     fn new(label: &'static str) -> Self {
+//         if WITH_TIMER {
+//             web_sys::console::time_with_label(label);
+//         }
+//         Self { label }
+//     }
+// }

 impl Drop for Timer {
     fn drop(&mut self) {
@@ -23,7 +23,7 @@ impl Drop for Timer {

 mod app;
 mod audio;
-mod model;
+pub mod languages;
 pub mod worker;
 pub use app::App;
 pub use worker::Worker;
candle-wasm-examples/whisper/src/model.rs (deleted)
@@ -1,417 +0,0 @@
-// We use anyhow rather than candle errors as it provides better support for getting the backtrace
-// back when using RUST_LIB_BACKTRACE=1.
-use anyhow::Result;
-use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-use serde::Deserialize;
-
-// The names in comments correspond to the original implementation:
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
-    pub num_mel_bins: usize,            // n_mels
-    pub max_source_positions: usize,    // n_audio_ctx
-    pub d_model: usize,                 // n_audio_state
-    pub encoder_attention_heads: usize, // n_audio_head
-    pub encoder_layers: usize,          // n_audio_layer
-    pub vocab_size: usize,              // n_vocab
-    pub max_target_positions: usize,    // n_text_ctx
-    // pub n_text_state: usize,
-    pub decoder_attention_heads: usize, // n_text_head
-    pub decoder_layers: usize,          // n_text_layer
-}
-
-impl Config {
-    pub fn tiny_en() -> Self {
-        Self {
-            num_mel_bins: 80,
-            vocab_size: 51864,
-            max_source_positions: 1500,
-            d_model: 384,
-            encoder_attention_heads: 6,
-            encoder_layers: 4,
-            max_target_positions: 448,
-            // n_text_state: 384,
-            decoder_attention_heads: 6,
-            decoder_layers: 4,
-        }
-    }
-}
-
-// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm
-// specific monitoring.
-#[derive(Debug)]
-struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-}
-
-impl Linear {
-    fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
-        Self { weight, bias }
-    }
-
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let _timer = crate::Timer::new("Linear::forward");
-        let w = match x.dims() {
-            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-            _ => self.weight.t()?,
-        };
-        let x = {
-            let _timer = crate::Timer::new("Linear::matmul");
-            x.matmul(&w)?
-        };
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    let bias = vb.get(size2, "bias")?;
-    Ok(Linear::new(weight, Some(bias)))
-}
-
-fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
-}
-
-fn conv1d(
-    in_channels: usize,
-    out_channels: usize,
-    kernel_size: usize,
-    config: Conv1dConfig,
-    vb: VarBuilder,
-) -> Result<Conv1d> {
-    let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
-    let bias = vb.get(out_channels, "bias")?;
-    Ok(Conv1d::new(weight, Some(bias), config))
-}
-
-fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
-    let weight = vb.get(size, "weight")?;
-    let bias = vb.get(size, "bias")?;
-    Ok(LayerNorm::new(weight, bias, 1e-5))
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
-struct MultiHeadAttention {
-    query: Linear,
-    key: Linear,
-    value: Linear,
-    out: Linear,
-    n_head: usize,
-    kv_cache: Option<(Tensor, Tensor)>,
-}
-
-impl MultiHeadAttention {
-    fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
-        let query = linear(n_state, n_state, vb.pp("q_proj"))?;
-        let value = linear(n_state, n_state, vb.pp("v_proj"))?;
-        let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
-        let out = linear(n_state, n_state, vb.pp("out_proj"))?;
-        Ok(Self {
-            query,
-            key,
-            value,
-            out,
-            n_head,
-            kv_cache: None,
-        })
-    }
-
-    fn forward(
-        &mut self,
-        x: &Tensor,
-        xa: Option<&Tensor>,
-        mask: Option<&Tensor>,
-        flush_cache: bool,
-    ) -> Result<Tensor> {
-        let _timer = crate::Timer::new("MultiHeadAttention::forward");
-        let q = self.query.forward(x)?;
-        let (k, v) = match xa {
-            None => {
-                let k = self.key.forward(x)?;
-                let v = self.value.forward(x)?;
-                (k, v)
-            }
-            Some(x) => {
-                if flush_cache {
-                    self.kv_cache = None;
-                }
-                if let Some((k, v)) = &self.kv_cache {
-                    (k.clone(), v.clone())
-                } else {
-                    let k = self.key.forward(x)?;
-                    let v = self.value.forward(x)?;
-                    self.kv_cache = Some((k.clone(), v.clone()));
-                    (k, v)
-                }
-            }
-        };
-        let wv = self.qkv_attention(&q, &k, &v, mask)?;
-        let out = self.out.forward(&wv)?;
-        Ok(out)
-    }
-
-    fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
-        let (n_batch, n_ctx, n_state) = x.dims3()?;
-        let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
-        Ok(x.reshape(target_dims)?.transpose(1, 2)?)
-    }
-
-    fn qkv_attention(
-        &self,
-        q: &Tensor,
-        k: &Tensor,
-        v: &Tensor,
-        mask: Option<&Tensor>,
-    ) -> Result<Tensor> {
-        let (_, n_ctx, n_state) = q.dims3()?;
-        let scale = ((n_state / self.n_head) as f64).powf(-0.25);
-        let q = (self.reshape_head(q)? * scale)?;
-        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
-        let v = self.reshape_head(v)?.contiguous()?;
-        let mut qk = {
-            let _timer = crate::Timer::new("qk::matmul");
-            q.matmul(&k)?
-        };
-        if let Some(mask) = mask {
-            let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
-            qk = qk.broadcast_add(&mask)?
-        }
-        let w = {
-            let _timer = crate::Timer::new("qk::softmax");
-            candle_nn::ops::softmax(&qk, candle::D::Minus1)?
-        };
-        let wv = {
-            let _timer = crate::Timer::new("wv::matmul");
-            w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?
-        };
-        Ok(wv)
-    }
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
-struct ResidualAttentionBlock {
-    attn: MultiHeadAttention,
-    attn_ln: LayerNorm,
-    cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
-    mlp_linear1: Linear,
-    mlp_linear2: Linear,
-    mlp_ln: LayerNorm,
-}
-
-impl ResidualAttentionBlock {
-    fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
-        let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
-        let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
-        let cross_attn = if ca {
-            let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
-            let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
-            Some((cross_attn, cross_attn_ln))
-        } else {
-            None
-        };
-        let n_mlp = n_state * 4;
-        let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
-        let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
-        let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
-        Ok(Self {
-            attn,
-            attn_ln,
-            cross_attn,
-            mlp_linear1,
-            mlp_linear2,
-            mlp_ln,
-        })
-    }
-
-    fn forward(
-        &mut self,
-        x: &Tensor,
-        xa: Option<&Tensor>,
-        mask: Option<&Tensor>,
-        flush_kv_cache: bool,
-    ) -> Result<Tensor> {
-        let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
-        let attn = self
-            .attn
-            .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
-        let mut x = (x + attn)?;
-        if let Some((attn, ln)) = &mut self.cross_attn {
-            x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
-        }
-        let mlp = self.mlp_linear2.forward(
-            &self
-                .mlp_linear1
-                .forward(&self.mlp_ln.forward(&x)?)?
-                .gelu()?,
-        )?;
-        Ok((x + mlp)?)
-    }
-}
-
-fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
-    let max_timescale = 10000f32;
-    let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
-    let inv_timescales: Vec<_> = (0..channels / 2)
-        .map(|i| (i as f32 * (-log_timescale_increment)).exp())
-        .collect();
-    let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
-    let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
-        .to_dtype(candle::DType::F32)?
-        .unsqueeze(1)?;
-    let sh = (length, channels / 2);
-    let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
-    let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
-    Ok(sincos)
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
-pub struct AudioEncoder {
-    conv1: Conv1d,
-    conv2: Conv1d,
-    positional_embedding: Tensor,
-    blocks: Vec<ResidualAttentionBlock>,
-    ln_post: LayerNorm,
-}
-
-impl AudioEncoder {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let n_state = cfg.d_model;
-        let n_head = cfg.encoder_attention_heads;
-        let n_ctx = cfg.max_source_positions;
-        let cfg1 = Conv1dConfig {
-            padding: 1,
-            stride: 1,
-            groups: 1,
-            dilation: 1,
-        };
-        let cfg2 = Conv1dConfig {
-            padding: 1,
-            stride: 2,
-            groups: 1,
-            dilation: 1,
-        };
-        let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
-        let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
-        let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
-        let blocks = (0..cfg.encoder_layers)
-            .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
-            })
-            .collect::<Result<Vec<_>>>()?;
-        let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
-        Ok(Self {
-            conv1,
-            conv2,
-            positional_embedding,
-            blocks,
-            ln_post,
-        })
-    }
-    pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
-        let _timer = crate::Timer::new("AudioEncoder::forward");
-        let x = {
-            let _timer = crate::Timer::new("conv1::forward");
-            self.conv1.forward(x)?.gelu()?
-        };
-        let x = {
-            let _timer = crate::Timer::new("conv2::forward");
-            self.conv2.forward(&x)?.gelu()?
-        };
-        let x = x.transpose(1, 2)?;
-        let (_bsize, seq_len, _hidden) = x.dims3()?;
-        let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
-        let mut x = x.broadcast_add(&positional_embedding)?;
-        for block in self.blocks.iter_mut() {
-            x = block.forward(&x, None, None, flush_kv_cache)?
-        }
-        let x = self.ln_post.forward(&x)?;
-        Ok(x)
-    }
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
-pub struct TextDecoder {
-    token_embedding: Embedding,
-    positional_embedding: Tensor,
-    blocks: Vec<ResidualAttentionBlock>,
-    ln: LayerNorm,
-    mask: Tensor,
-}
-
-impl TextDecoder {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let _timer = crate::Timer::new("TextDecoder::forward");
-        let n_state = cfg.d_model;
-        let n_head = cfg.decoder_attention_heads;
-        let n_ctx = cfg.max_target_positions;
-        let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
-        let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
-        let blocks = (0..cfg.decoder_layers)
-            .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
-            })
-            .collect::<Result<Vec<_>>>()?;
-        let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
-        let mask: Vec<_> = (0..n_ctx)
-            .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
-            .collect();
-        let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
-
-        Ok(Self {
-            token_embedding,
-            positional_embedding,
-            blocks,
-            ln,
-            mask,
-        })
-    }
-
-    pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
-        let x_dims = x.dims();
-        let last = x_dims[x_dims.len() - 1];
-        let token_embedding = self.token_embedding.forward(x)?;
-        let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
-        let mut x = token_embedding.broadcast_add(&positional_embedding)?;
-        for block in self.blocks.iter_mut() {
-            x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
-        }
-        let x = self.ln.forward(&x)?;
-        let w = self
-            .token_embedding
-            .embeddings()
-            .broadcast_left(x_dims[0])?;
-        let logits = x.matmul(&w.t()?)?;
-        Ok(logits)
-    }
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
-pub struct Whisper {
-    pub encoder: AudioEncoder,
-    pub decoder: TextDecoder,
-    pub config: Config,
-}
-
-impl Whisper {
-    pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
-        let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
-        let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
-        Ok(Self {
-            encoder,
-            decoder,
-            config,
-        })
-    }
-}
candle-wasm-examples/whisper/src/worker.rs
@@ -1,7 +1,8 @@
-use crate::model::{Config, Whisper};
+use crate::languages::LANGUAGES;
 use anyhow::Error as E;
-use candle::{safetensors::Load, DType, Device, Tensor};
+use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
 use candle_nn::{ops::softmax, VarBuilder};
+pub use candle_transformers::models::whisper::{self as m, Config};
 use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
 use serde::{Deserialize, Serialize};
 use tokenizers::Tokenizer;
@@ -25,38 +26,46 @@ macro_rules! console_log {

 pub const DTYPE: DType = DType::F32;

-// Audio parameters.
-pub const SAMPLE_RATE: usize = 16000;
-pub const N_FFT: usize = 400;
-pub const N_MELS: usize = 80;
-pub const HOP_LENGTH: usize = 160;
-pub const CHUNK_LENGTH: usize = 30;
-pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-
-pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
-pub const LOGPROB_THRESHOLD: f64 = -1.0;
-pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-
-// Tokenizer dependent bits.
-const SOT_TOKEN: &str = "<|startoftranscript|>";
-const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
-const TRANSLATE_TOKEN: &str = "<|translate|>";
-const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
-const EOT_TOKEN: &str = "<|endoftext|>";
-const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
-
-// From the _get_suppress_tokens function + 50362 (no timestamp)
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
-pub const SUPPRESS_TOKENS: [u32; 91] = [
-    1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
-    366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
-    1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
-    10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
-    19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
-    47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
-];
+pub enum Model {
+    Normal(m::model::Whisper),
+    Quantized(m::quantized_model::Whisper),
+}
+
+// Maybe we should use some traits rather than doing the dispatch for all these.
+impl Model {
+    pub fn config(&self) -> &Config {
+        match self {
+            Self::Normal(m) => &m.config,
+            Self::Quantized(m) => &m.config,
+        }
+    }
+
+    pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.encoder.forward(x, flush),
+            Self::Quantized(m) => m.encoder.forward(x, flush),
+        }
+    }
+
+    pub fn decoder_forward(
+        &mut self,
+        x: &Tensor,
+        xa: &Tensor,
+        flush: bool,
+    ) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.forward(x, xa, flush),
+            Self::Quantized(m) => m.decoder.forward(x, xa, flush),
+        }
+    }
+
+    pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::Normal(m) => m.decoder.final_linear(x),
+            Self::Quantized(m) => m.decoder.final_linear(x),
+        }
+    }
+}

 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct DecodingResult {
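The comment in the hunk above ("Maybe we should use some traits...") points at an alternative to the per-method `match`. Below is a minimal sketch of that trait-based refactor, purely an illustration and not part of this commit; `Tensor` and the two unit structs are stand-ins for the candle types:

```rust
// Hypothetical trait-based dispatch for the float and quantized models.
struct Tensor;

trait WhisperForward {
    fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> Tensor;
}

struct Normal;
struct Quantized;

impl WhisperForward for Normal {
    fn encoder_forward(&mut self, _x: &Tensor, _flush: bool) -> Tensor {
        Tensor
    }
}

impl WhisperForward for Quantized {
    fn encoder_forward(&mut self, _x: &Tensor, _flush: bool) -> Tensor {
        Tensor
    }
}

fn main() {
    // A Box<dyn Trait> replaces the enum and the repeated match arms.
    let mut model: Box<dyn WhisperForward> = Box::new(Quantized);
    let _features = model.encoder_forward(&Tensor, true);
}
```

The enum approach the commit actually takes avoids boxing and keeps the concrete types visible, at the cost of one `match` per forwarded method.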
@@ -77,8 +86,13 @@ pub struct Segment {

 #[allow(unused)]
 pub struct Decoder {
-    model: Whisper,
+    model: Model,
+    rng: rand::rngs::StdRng,
+    task: Option<Task>,
+    language: Option<String>,
+    is_multilingual: bool,
     mel_filters: Vec<f32>,
+    timestamps: bool,
     tokenizer: Tokenizer,
     suppress_tokens: Tensor,
     sot_token: u32,
@@ -90,32 +104,43 @@ pub struct Decoder {
 }

 impl Decoder {
+    #[allow(clippy::too_many_arguments)]
     fn new(
-        model: Whisper,
+        model: Model,
         tokenizer: Tokenizer,
         mel_filters: Vec<f32>,
         device: &Device,
+        task: Option<Task>,
+        language: Option<String>,
+        is_multilingual: bool,
+        timestamps: bool,
     ) -> anyhow::Result<Self> {
-        let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
+        let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
             .map(|i| {
-                if SUPPRESS_TOKENS.contains(&i) {
+                if model.config().suppress_tokens.contains(&i) {
                     f32::NEG_INFINITY
                 } else {
                     0f32
                 }
             })
             .collect();
-        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
-        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
-        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
-        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
-        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
+        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
+        let seed = 299792458;
         Ok(Self {
             model,
-            mel_filters,
+            rng: StdRng::seed_from_u64(seed),
             tokenizer,
+            mel_filters,
+            task,
+            timestamps,
+            language,
+            is_multilingual,
             suppress_tokens,
             sot_token,
             transcribe_token,
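The constructor above now derives the suppression mask from the model's `config.suppress_tokens` list instead of the old hard-coded `SUPPRESS_TOKENS` array. A standalone sketch of that mask construction (a hypothetical free function with the same logic):

```rust
// Token ids listed in `suppress` get a logit offset of negative infinity,
// so they can never win the argmax or receive sampling probability mass
// once the mask is broadcast-added to the logits.
fn suppress_mask(vocab_size: u32, suppress: &[u32]) -> Vec<f32> {
    (0..vocab_size)
        .map(|i| {
            if suppress.contains(&i) {
                f32::NEG_INFINITY
            } else {
                0f32
            }
        })
        .collect()
}

fn main() {
    let mask = suppress_mask(8, &[1, 3]);
    assert!(mask[1] == f32::NEG_INFINITY && mask[3] == f32::NEG_INFINITY);
    assert_eq!(mask[0], 0f32);
}
```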
@@ -126,40 +151,73 @@ impl Decoder {
         })
     }

-    fn decode(&mut self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
+    fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
         let model = &mut self.model;
-        let audio_features = model.encoder.forward(mel, true)?;
-        console_log!("audio features: {:?}", audio_features.dims());
-        let sample_len = model.config.max_target_positions / 2;
+        let language_token = match (self.is_multilingual, &self.language) {
+            (true, None) => Some(detect_language(model, &self.tokenizer, mel)?),
+            (false, None) => None,
+            (true, Some(language)) => {
+                match token_id(&self.tokenizer, &format!("<|{:?}|>", self.language)) {
+                    Ok(token_id) => Some(token_id),
+                    Err(_) => anyhow::bail!("language {language} is not supported"),
+                }
+            }
+            (false, Some(_)) => {
+                anyhow::bail!("a language cannot be set for non-multilingual models")
+            }
+        };
+
+        let audio_features = model.encoder_forward(mel, true)?;
+        println!("audio features: {:?}", audio_features.dims());
+        let sample_len = model.config().max_target_positions / 2;
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
-        let mut tokens = vec![self.sot_token, self.transcribe_token];
+        let mut tokens = vec![self.sot_token];
+        if let Some(language_token) = language_token {
+            tokens.push(language_token);
+        }
+        match self.task {
+            None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+            Some(Task::Translate) => tokens.push(self.translate_token),
+        }
+        if !self.timestamps {
+            tokens.push(self.no_timestamps_token);
+        }
         for i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;

             // The model expects a batch dim but this inference loop does not handle
             // it so we add it at this point.
             let tokens_t = tokens_t.unsqueeze(0)?;
-            let logits = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
-            let logits = logits.squeeze(0)?;
+            let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;

             // Extract the no speech probability on the first iteration by looking at the first
             // token logits and the probability for the according token.
             if i == 0 {
-                no_speech_prob = softmax(&logits.get(0)?, 0)?
-                    .get(self.no_speech_token as usize)?
+                let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+                no_speech_prob = softmax(&logits, 0)?
+                    .i(self.no_speech_token as usize)?
                     .to_scalar::<f32>()? as f64;
             }

-            let (seq_len, _) = logits.dims2()?;
-            let logits = logits
-                .get(seq_len - 1)?
-                .broadcast_add(&self.suppress_tokens)?;
+            let (_, seq_len, _) = ys.dims3()?;
+            let logits = model
+                .decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
+                .i(0)?
+                .i(0)?;
+            // TODO: Besides suppress tokens, we should apply the heuristics from
+            // ApplyTimestampRules, i.e.:
+            // - Timestamps come in pairs, except before EOT.
+            // - Timestamps should be non-decreasing.
+            // - If the sum of the probabilities of timestamps is higher than any other tokens,
+            //   only consider timestamps when sampling.
+            // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
+            let logits = logits.broadcast_add(&self.suppress_tokens)?;
             let next_token = if t > 0f64 {
                 let prs = softmax(&(&logits / t)?, 0)?;
                 let logits_v: Vec<f32> = prs.to_vec1()?;
                 let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
-                distr.sample(rng) as u32
+                distr.sample(&mut self.rng) as u32
             } else {
                 let logits_v: Vec<f32> = logits.to_vec1()?;
                 logits_v
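For reference, the token prefix assembled at the top of `decode` follows Whisper's usual prompt layout: start token, then an optional language token (detected or user-supplied), then the task token, then `<|notimestamps|>` when timestamps are disabled. A small illustrative sketch of the resulting prefixes (token strings only; the actual ids depend on the tokenizer):

```rust
// Hypothetical illustration of the prompt prefix order built in `decode`.
fn prompt(language: Option<&str>, translate: bool, timestamps: bool) -> Vec<String> {
    let mut tokens = vec!["<|startoftranscript|>".to_string()];
    if let Some(lang) = language {
        tokens.push(format!("<|{lang}|>"));
    }
    let task = if translate { "<|translate|>" } else { "<|transcribe|>" };
    tokens.push(task.to_string());
    if !timestamps {
        tokens.push("<|notimestamps|>".to_string());
    }
    tokens
}

fn main() {
    assert_eq!(
        prompt(Some("fr"), false, false),
        ["<|startoftranscript|>", "<|fr|>", "<|transcribe|>", "<|notimestamps|>"]
    );
}
```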
@@ -171,9 +229,9 @@ impl Decoder {
             };
             tokens.push(next_token);
             let prob = softmax(&logits, candle::D::Minus1)?
-                .get(next_token as usize)?
+                .i(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
+            if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
                 break;
             }
             sum_logprob += prob.ln();
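Aside on the `broadcast_add(&self.suppress_tokens)` step in the hunk above: the suppression bias is a vocab-sized tensor holding `f32::NEG_INFINITY` at the ids Whisper should never emit and `0` elsewhere, so adding it to the logits drives the suppressed tokens' post-softmax probability to zero. A minimal sketch of how such a mask can be built; the `suppress` list and `vocab_size` parameter here are illustrative assumptions, not fields from this diff:

```rust
use candle::{Device, Result, Tensor};

// Build an additive bias that forbids the given token ids:
// NEG_INFINITY entries vanish after softmax, 0.0 entries are no-ops.
fn suppress_mask(suppress: &[u32], vocab_size: u32, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..vocab_size)
        .map(|t| {
            if suppress.contains(&t) {
                f32::NEG_INFINITY
            } else {
                0f32
            }
        })
        .collect();
    Tensor::new(mask.as_slice(), device)
}
```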
@@ -191,22 +249,18 @@ impl Decoder {
         })
     }
 
-    fn decode_with_fallback(
-        &mut self,
-        segment: &Tensor,
-        rng: &mut StdRng,
-    ) -> anyhow::Result<DecodingResult> {
-        for (i, &t) in TEMPERATURES.iter().enumerate() {
-            let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
-            if i == TEMPERATURES.len() - 1 {
+    fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
+        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
+            let dr: Result<DecodingResult, _> = self.decode(segment, t);
+            if i == m::TEMPERATURES.len() - 1 {
                 return dr;
             }
             // On errors, we try again with a different temperature.
             match dr {
                 Ok(dr) => {
-                    let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                        || dr.avg_logprob < LOGPROB_THRESHOLD;
-                    if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
                         return Ok(dr);
                     }
                 }
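The fallback loop above retries the decode at each temperature in `m::TEMPERATURES` until the quality thresholds pass. A sketch of the acceptance rule it applies, with the threshold values stated as assumptions (they match the usual Whisper defaults):

```rust
// Assumed Whisper defaults for the m:: thresholds referenced above.
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
const LOGPROB_THRESHOLD: f64 = -1.0;
const NO_SPEECH_THRESHOLD: f64 = 0.6;

// A decode is kept when it is neither degenerate (high compression ratio)
// nor under-confident (low average logprob), or when it is likely silence.
fn accept(compression_ratio: f64, avg_logprob: f64, no_speech_prob: f64) -> bool {
    let needs_fallback =
        compression_ratio > COMPRESSION_RATIO_THRESHOLD || avg_logprob < LOGPROB_THRESHOLD;
    !needs_fallback || no_speech_prob > NO_SPEECH_THRESHOLD
}
```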
@@ -219,18 +273,17 @@ impl Decoder {
     }
 
     fn run(&mut self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
-        let mut rng = StdRng::seed_from_u64(299792458);
         let (_, _, content_frames) = mel.dims3()?;
         let mut seek = 0;
         let mut segments = vec![];
         while seek < content_frames {
-            let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
-            let segment_size = usize::min(content_frames - seek, N_FRAMES);
+            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
             let mel_segment = mel.narrow(2, seek, segment_size)?;
-            let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
-            let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
+            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let dr = self.decode_with_fallback(&mel_segment)?;
             seek += segment_size;
-            if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
                 console_log!("no speech detected, skipping {seek} {dr:?}");
                 continue;
             }
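For the time arithmetic in `run` above: with Whisper's standard constants (`HOP_LENGTH = 160`, `SAMPLE_RATE = 16_000`, `N_FRAMES = 3000`, assumed here to be what the `m::` module exposes), one mel frame advances 10 ms, so a full window covers 30 s of audio, which is the unit `seek` steps through. A quick check under those assumed values:

```rust
// Standard Whisper mel constants (assumed values).
const HOP_LENGTH: usize = 160; // audio samples per mel frame
const SAMPLE_RATE: usize = 16_000; // audio samples per second
const N_FRAMES: usize = 3000; // mel frames in one decoding window

fn main() {
    // Each frame advances HOP_LENGTH / SAMPLE_RATE = 0.01 s.
    let frame_s = HOP_LENGTH as f64 / SAMPLE_RATE as f64;
    assert!((frame_s - 0.01).abs() < 1e-9);
    // A full window is therefore 3000 * 0.01 = 30 s, matching segment_duration above.
    let window_s = (N_FRAMES * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
    assert!((window_s - 30.0).abs() < 1e-9);
}
```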
@@ -247,17 +300,39 @@ impl Decoder {
 
     pub fn load(md: ModelData) -> anyhow::Result<Self> {
         let device = Device::Cpu;
-        let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
+        let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(E::msg)?;
 
         let mel_filters = safetensors::tensor::SafeTensors::deserialize(&md.mel_filters)?;
         let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
         console_log!("loaded mel filters {:?}", mel_filters.shape());
         let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
-        let vb = VarBuilder::from_buffered_safetensors(md.weights, DTYPE, &device)?;
-        let config = Config::tiny_en();
-        let whisper = Whisper::load(&vb, config)?;
+        let config: Config = serde_json::from_slice(&md.config)?;
+        let model = if md.quantized {
+            let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
+                &md.weights,
+            )?;
+            Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
+        } else {
+            let vb = VarBuilder::from_buffered_safetensors(md.weights, m::DTYPE, &device)?;
+            Model::Normal(m::model::Whisper::load(&vb, config)?)
+        };
         console_log!("done loading model");
-        let decoder = Self::new(whisper, tokenizer, mel_filters, &device)?;
+        let task = match md.task.as_deref() {
+            Some("translate") => Some(Task::Translate),
+            _ => Some(Task::Transcribe),
+        };
+
+        let decoder = Self::new(
+            model,
+            tokenizer,
+            mel_filters,
+            &device,
+            task,
+            md.language,
+            md.is_multilingual,
+            md.timestamps,
+        )?;
         Ok(decoder)
     }
 
@@ -266,8 +341,8 @@ impl Decoder {
         let mut wav_input = std::io::Cursor::new(wav_input);
         let (header, data) = wav::read(&mut wav_input)?;
         console_log!("loaded wav data: {header:?}");
-        if header.sampling_rate != SAMPLE_RATE as u32 {
-            anyhow::bail!("wav file must have a {SAMPLE_RATE} sampling rate");
+        if header.sampling_rate != m::SAMPLE_RATE as u32 {
+            anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
         }
         let data = data.as_sixteen().expect("expected 16 bit wav file");
         let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@@ -277,27 +352,74 @@ impl Decoder {
         console_log!("pcm data loaded {}", pcm_data.len());
         let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?;
         let mel_len = mel.len();
-        let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+        let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
         console_log!("loaded mel: {:?}", mel.dims());
         let segments = self.run(&mel)?;
         Ok(segments)
     }
 }
 
+/// Returns the token id for the selected language.
+pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32, E> {
+    console_log!("detecting language");
+    let (_bsize, _, seq_len) = mel.dims3()?;
+    let mel = mel.narrow(
+        2,
+        0,
+        usize::min(seq_len, model.config().max_source_positions),
+    )?;
+    let device = mel.device();
+
+    let language_token_ids = LANGUAGES
+        .iter()
+        .map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
+        .map(|e| e.map_err(E::msg))
+        .collect::<Result<Vec<_>, E>>()?;
+
+    let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
+    let audio_features = model.encoder_forward(&mel, true)?;
+    let tokens = Tensor::new(&[[sot_token]], device)?;
+    let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
+    let ys = model.decoder_forward(&tokens, &audio_features, true)?;
+    let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
+    let logits = logits.index_select(&language_token_ids, 0)?;
+    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
+    let probs = probs.to_vec1::<f32>()?;
+    let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
+    probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+    for ((_, language), p) in probs.iter().take(5) {
+        println!("{language}: {p}")
+    }
+    let token = &format!("<|{}|>", probs[0].0 .0);
+    let language = token_id(tokenizer, token)?;
+    console_log!("detected language: {language} {token}");
+    Ok(language)
+}
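A possible call site for the new `detect_language` helper, sketched under the assumption that `Model`, `token_id`, and the tokenizer from this file are in scope; `resolve_language_token` is a hypothetical wrapper, not part of the diff. The idea: an explicitly chosen language wins, detection is the multilingual fallback, and English-only checkpoints skip both.

```rust
use anyhow::Result;
use candle::Tensor;
use tokenizers::Tokenizer;

// Hypothetical helper: decide which language token (if any) seeds the decoder.
fn resolve_language_token(
    model: &mut Model,
    tokenizer: &Tokenizer,
    mel: &Tensor,
    is_multilingual: bool,
    language: Option<&str>,
) -> Result<Option<u32>> {
    if !is_multilingual {
        // English-only checkpoints have no language tokens.
        return Ok(None);
    }
    let token = match language {
        // A user-selected language takes precedence over detection.
        Some(lang) => token_id(tokenizer, &format!("<|{lang}|>"))?,
        None => detect_language(model, tokenizer, mel)?,
    };
    Ok(Some(token))
}
```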
 pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
     match tokenizer.token_to_id(token) {
         None => candle::bail!("no token-id for {token}"),
         Some(id) => Ok(id),
     }
 }
+
+#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
+pub enum Task {
+    Transcribe,
+    Translate,
+}
 
 // Communication to the worker happens through bincode, the model weights and configs are fetched
 // on the main thread and transferred via the following structure.
 #[derive(Serialize, Deserialize)]
 pub struct ModelData {
+    pub weights: Vec<u8>,
     pub tokenizer: Vec<u8>,
     pub mel_filters: Vec<u8>,
-    pub weights: Vec<u8>,
+    pub config: Vec<u8>,
+    pub quantized: bool,
+    pub timestamps: bool,
+    pub is_multilingual: bool,
+    pub language: Option<String>,
+    pub task: Option<String>,
 }
 
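To make the bincode comment above concrete, here is a minimal round-trip sketch, assuming the `bincode` 1.x API and a `serde`-derived struct; `Payload` is an illustrative stand-in for `ModelData`, not the real type:

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct Payload {
    weights: Vec<u8>,
    quantized: bool,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let md = Payload { weights: vec![1, 2, 3], quantized: true };
    // Main thread: encode the payload before handing it to the worker.
    let bytes = bincode::serialize(&md)?;
    // Worker side: decode the bytes back into the typed structure.
    let back: Payload = bincode::deserialize(&bytes)?;
    assert_eq!(back, md);
    Ok(())
}
```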
 pub struct Worker {
@@ -17,23 +17,46 @@ class Whisper {
   static instance = {};
   // Retrieve the Whisper model. When called for the first time,
   // this will load the model and save it for future use.
-  static async getInstance(weightsURL, modelID, tokenizerURL, mel_filtersURL) {
+  static async getInstance(params) {
+    const {
+      weightsURL,
+      modelID,
+      tokenizerURL,
+      mel_filtersURL,
+      configURL,
+      quantized,
+      is_multilingual,
+      timestamps,
+      task,
+      language,
+    } = params;
     // load individual modelID only once
     if (!this.instance[modelID]) {
       await init();
 
       self.postMessage({ status: "loading", message: "Loading Model" });
-      const [weightsArrayU8, tokenizerArrayU8, mel_filtersArrayU8] =
-        await Promise.all([
-          fetchArrayBuffer(weightsURL),
-          fetchArrayBuffer(tokenizerURL),
-          fetchArrayBuffer(mel_filtersURL),
-        ]);
+      const [
+        weightsArrayU8,
+        tokenizerArrayU8,
+        mel_filtersArrayU8,
+        configArrayU8,
+      ] = await Promise.all([
+        fetchArrayBuffer(weightsURL),
+        fetchArrayBuffer(tokenizerURL),
+        fetchArrayBuffer(mel_filtersURL),
+        fetchArrayBuffer(configURL),
+      ]);
 
       this.instance[modelID] = new Decoder(
         weightsArrayU8,
         tokenizerArrayU8,
-        mel_filtersArrayU8
+        mel_filtersArrayU8,
+        configArrayU8,
+        quantized,
+        is_multilingual,
+        timestamps,
+        task,
+        language
       );
     } else {
       self.postMessage({ status: "loading", message: "Model Already Loaded" });
@@ -43,17 +66,37 @@ class Whisper {
 }
 
 self.addEventListener("message", async (event) => {
-  const { weightsURL, modelID, tokenizerURL, mel_filtersURL, audioURL } =
-    event.data;
+  const {
+    weightsURL,
+    modelID,
+    tokenizerURL,
+    configURL,
+    mel_filtersURL,
+    audioURL,
+  } = event.data;
   try {
     self.postMessage({ status: "decoding", message: "Starting Decoder" });
-    const decoder = await Whisper.getInstance(
+    let quantized = false;
+    if (modelID.includes("quantized")) {
+      quantized = true;
+    }
+    let is_multilingual = false;
+    if (modelID.includes("multilingual")) {
+      is_multilingual = true;
+    }
+    let timestamps = true;
+    const decoder = await Whisper.getInstance({
       weightsURL,
       modelID,
       tokenizerURL,
-      mel_filtersURL
-    );
+      mel_filtersURL,
+      configURL,
+      quantized,
+      is_multilingual,
+      timestamps,
+      task: null,
+      language: null,
+    });
 
     self.postMessage({ status: "decoding", message: "Loading Audio" });
     const audioArrayU8 = await fetchArrayBuffer(audioURL);