BERT Wasm (#902)

* implement wasm module

* add example to workspace

* add UI explore semantic similiarity

* change status messages

* formatting

* minor changes
This commit is contained in:
Radamés Ajna
2023-09-19 13:31:37 -07:00
committed by GitHub
parent 8696f64bae
commit 7ad82b87e4
9 changed files with 719 additions and 4 deletions

View File

@ -11,11 +11,9 @@ members = [
"candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
[workspace.package]

View File

@ -0,0 +1,33 @@
[package]
name = "candle-wasm-example-bert"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.2.2" }
candle-transformers = { path = "../../candle-transformers", version = "0.2.2" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
serde-wasm-bindgen = "0.6.0"

View File

@ -0,0 +1,26 @@
## Running BERT with Candle and WASM
Here, we provide two examples of how to run Bert using a Candle-compiled WASM binary and runtime.
### Vanilla JS and WebWorkers
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
```bash
sh build-lib.sh
```
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
```js
import init, { Model } from "./build/m.js";
```
The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:
```bash
python -m http.server
```
Then open `http://localhost:8000/lib-example.html` in your browser.

View File

@ -0,0 +1,77 @@
//load Candle Bert Module wasm module
import init, { Model } from "./build/m.js";
async function fetchArrayBuffer(url) {
const cacheName = "bert-candle-cache";
const cache = await caches.open(cacheName);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const data = await cachedResponse.arrayBuffer();
return new Uint8Array(data);
}
const res = await fetch(url, { cache: "force-cache" });
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
class Bert {
static instance = {};
static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8, mel_filtersArrayU8] =
await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(configURL),
]);
this.instance[modelID] = new Model(
weightsArrayU8,
tokenizerArrayU8,
mel_filtersArrayU8
);
} else {
self.postMessage({ status: "ready", message: "Model Already Loaded" });
}
return this.instance[modelID];
}
}
self.addEventListener("message", async (event) => {
const {
weightsURL,
tokenizerURL,
configURL,
modelID,
sentences,
normalize = true,
} = event.data;
try {
self.postMessage({ status: "ready", message: "Starting Bert Model" });
const model = await Bert.getInstance(
weightsURL,
tokenizerURL,
configURL,
modelID
);
self.postMessage({
status: "embedding",
message: "Calculating Embeddings",
});
const output = model.get_embeddings({
sentences: sentences,
normalize_embeddings: normalize,
});
self.postMessage({
status: "complete",
message: "complete",
output: output.data,
});
} catch (e) {
self.postMessage({ error: e });
}
});

View File

@ -0,0 +1,2 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web

View File

@ -0,0 +1,368 @@
<html>
<head>
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
<title>Candle Bert</title>
</head>
<body></body>
</html>
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<style>
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
html,
body {
font-family: "Source Sans 3", sans-serif;
}
</style>
<script src="https://cdn.tailwindcss.com"></script>
<script type="module" src="./code.js"></script>
<script type="module">
import { hcl } from "https://cdn.skypack.dev/d3-color@3";
import { interpolateReds } from "https://cdn.skypack.dev/d3-scale-chromatic@3";
import { scaleLinear } from "https://cdn.skypack.dev/d3-scale@4";
import {
getModelInfo,
getEmbeddings,
getWikiText,
cosineSimilarity,
} from "./utils.js";
const bertWorker = new Worker("./bertWorker.js", {
type: "module",
});
const inputContainerEL = document.querySelector("#input-container");
const textAreaEl = document.querySelector("#input-area");
const outputAreaEl = document.querySelector("#output-area");
const formEl = document.querySelector("#form");
const searchInputEl = document.querySelector("#search-input");
const formWikiEl = document.querySelector("#form-wiki");
const searchWikiEl = document.querySelector("#search-wiki");
const outputStatusEl = document.querySelector("#output-status");
const modelSelectEl = document.querySelector("#model");
const sentencesRegex =
/(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z]\.)(?<=\.|\?)\s/gm;
let sentenceEmbeddings = [];
let currInputText = "";
let isCalculating = false;
function toggleTextArea(state) {
if (state) {
textAreaEl.hidden = false;
textAreaEl.focus();
} else {
textAreaEl.hidden = true;
}
}
inputContainerEL.addEventListener("focus", (e) => {
toggleTextArea(true);
});
textAreaEl.addEventListener("blur", (e) => {
toggleTextArea(false);
});
textAreaEl.addEventListener("focusout", (e) => {
toggleTextArea(false);
if (currInputText === textAreaEl.value || isCalculating) return;
populateOutputArea(textAreaEl.value);
calculateEmbeddings(textAreaEl.value);
});
modelSelectEl.addEventListener("change", (e) => {
if (currInputText === "" || isCalculating) return;
populateOutputArea(textAreaEl.value);
calculateEmbeddings(textAreaEl.value);
});
function populateOutputArea(text) {
currInputText = text;
const sentences = text.split(sentencesRegex);
outputAreaEl.innerHTML = "";
for (const [id, sentence] of sentences.entries()) {
const sentenceEl = document.createElement("span");
sentenceEl.id = `sentence-${id}`;
sentenceEl.innerText = sentence + " ";
outputAreaEl.appendChild(sentenceEl);
}
}
formEl.addEventListener("submit", async (e) => {
e.preventDefault();
if (isCalculating || currInputText === "") return;
toggleInputs(true);
const modelID = modelSelectEl.value;
const { modelURL, tokenizerURL, configURL, search_prefix } =
getModelInfo(modelID);
const text = searchInputEl.value;
const query = search_prefix + searchInputEl.value;
outputStatusEl.classList.remove("invisible");
outputStatusEl.innerText = "Calculating embeddings for query...";
isCalculating = true;
const out = await getEmbeddings(
bertWorker,
modelURL,
tokenizerURL,
configURL,
modelID,
[query]
);
outputStatusEl.classList.add("invisible");
const queryEmbeddings = out.output[0];
// calculate cosine similarity with all sentences given the query
const distances = sentenceEmbeddings
.map((embedding, id) => ({
id,
similarity: cosineSimilarity(queryEmbeddings, embedding),
}))
.sort((a, b) => b.similarity - a.similarity)
// getting top 10 most similar sentences
.slice(0, 10);
const colorScale = scaleLinear()
.domain([
distances[distances.length - 1].similarity,
distances[0].similarity,
])
.range([0, 1])
.interpolate(() => interpolateReds);
outputAreaEl.querySelectorAll("span").forEach((el) => {
el.style.color = "unset";
el.style.backgroundColor = "unset";
});
distances.forEach((d) => {
const el = outputAreaEl.querySelector(`#sentence-${d.id}`);
const color = colorScale(d.similarity);
const fontColor = hcl(color).l < 70 ? "white" : "black";
el.style.color = fontColor;
el.style.backgroundColor = color;
});
outputAreaEl
.querySelector(`#sentence-${distances[0].id}`)
.scrollIntoView({
behavior: "smooth",
block: "center",
inline: "nearest",
});
isCalculating = false;
toggleInputs(false);
});
async function calculateEmbeddings(text) {
isCalculating = true;
toggleInputs(true);
const modelID = modelSelectEl.value;
const { modelURL, tokenizerURL, configURL, document_prefix } =
getModelInfo(modelID);
const sentences = text.split(sentencesRegex);
const allEmbeddings = [];
outputStatusEl.classList.remove("invisible");
for (const [id, sentence] of sentences.entries()) {
const query = document_prefix + sentence;
outputStatusEl.innerText = `Calculating embeddings: sentence ${
id + 1
} of ${sentences.length}`;
const embeddings = await getEmbeddings(
bertWorker,
modelURL,
tokenizerURL,
configURL,
modelID,
[query],
updateStatus
);
allEmbeddings.push(embeddings);
}
outputStatusEl.classList.add("invisible");
sentenceEmbeddings = allEmbeddings.map((e) => e.output[0]);
isCalculating = false;
toggleInputs(false);
}
function updateStatus(data) {
if ("status" in data) {
if (data.status === "loading") {
outputStatusEl.innerText = data.message;
outputStatusEl.classList.remove("invisible");
}
}
}
function toggleInputs(state) {
const interactive = document.querySelectorAll(".interactive");
interactive.forEach((el) => {
if (state) {
el.disabled = true;
} else {
el.disabled = false;
}
});
}
searchWikiEl.addEventListener("input", () => {
searchWikiEl.setCustomValidity("");
});
formWikiEl.addEventListener("submit", async (e) => {
e.preventDefault();
if ("example" in e.submitter.dataset) {
searchWikiEl.value = e.submitter.innerText;
}
const text = searchWikiEl.value;
if (isCalculating || text === "") return;
try {
const wikiText = await getWikiText(text);
searchWikiEl.setCustomValidity("");
textAreaEl.innerHTML = wikiText;
populateOutputArea(wikiText);
calculateEmbeddings(wikiText);
searchWikiEl.value = "";
} catch {
searchWikiEl.setCustomValidity("Invalid Wikipedia article name");
searchWikiEl.reportValidity();
}
});
</script>
</head>
<body class="container max-w-4xl mx-auto p-4">
<main class="grid grid-cols-1 gap-5 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
<h1 class="text-5xl font-bold">Candle BERT</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
Running sentence embeddings and similarity search in the browser using
the Bert Model written with
<a
href="https://github.com/huggingface/candle/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>Candle
</a>
and compiled to Wasm. Embeddings models from are from
<a
href="https://huggingface.co/sentence-transformers/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>
Sentence Transformers
</a>
and
<a
href="https://huggingface.co/intfloat/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>
Liang Wang - e5 Models
</a>
</p>
</div>
<div>
<label for="model" class="font-medium block">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light interactive disabled:cursor-not-allowed w-full max-w-max"
>
<option value="intfloat_e5_small_v2" selected>
intfloat/e5-small-v2 (133 MB)
</option>
<option value="intfloat_e5_base_v2">
intfloat/e5-base-v2 (438 MB)
</option>
<option value="intfloat_multilingual_e5_small">
intfloat/multilingual-e5-small (471 MB)
</option>
<option value="sentence_transformers_all_MiniLM_L6_v2">
sentence-transformers/all-MiniLM-L6-v2 (90.9 MB)
</option>
<option value="sentence_transformers_all_MiniLM_L12_v2">
sentence-transformers/all-MiniLM-L12-v2 (133 MB)
</option>
</select>
</div>
<div>
<h3 class="font-medium">Examples:</h3>
<form
id="form-wiki"
class="flex text-xs rounded-md justify-between w-min gap-3"
>
<input type="submit" hidden />
<button data-example class="disabled:cursor-not-allowed interactive">
Pizza
</button>
<button data-example class="disabled:cursor-not-allowed interactive">
Paris
</button>
<button data-example class="disabled:cursor-not-allowed interactive">
Physics
</button>
<input
type="text"
id="search-wiki"
title="Search Wikipedia article by title"
class="font-light py-0 mx-1 resize-none outline-none w-32 disabled:cursor-not-allowed interactive"
placeholder="Load Wikipedia article..."
/>
<button
title="Search Wikipedia article and load into input"
class="bg-gray-700 hover:bg-gray-800 text-white font-normal px-2 py-1 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
>
Load
</button>
</form>
</div>
<form
id="form"
class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
>
<input type="submit" hidden />
<input
type="text"
id="search-input"
class="font-light w-full px-3 py-2 mx-1 resize-none outline-none interactive disabled:cursor-not-allowed"
placeholder="Search query here..."
/>
<button
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
>
Search
</button>
</form>
<div>
<h3 class="font-medium">Input text:</h3>
<div class="flex justify-between items-center">
<div class="rounded-md inline text-xs">
<span id="output-status" class="m-auto font-light invisible"
>C</span
>
</div>
</div>
<div
id="input-container"
tabindex="0"
class="min-h-[250px] bg-slate-100 text-gray-500 rounded-md p-4 flex flex-col gap-2 relative"
>
<textarea
id="input-area"
hidden
value=""
placeholder="Input text to perform semantic similarity search..."
class="flex-1 resize-none outline-none left-0 right-0 top-0 bottom-0 m-4 absolute interactive disabled:invisible"
></textarea>
<p id="output-area" class="grid-rows-2">
Input text to perform semantic similarity search...
</p>
</div>
</div>
</main>
</body>
</html>

View File

@ -0,0 +1,92 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use candle_wasm_example_bert::console_log;
use tokenizers::{PaddingParams, Tokenizer};
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct Model {
bert: BertModel,
tokenizer: Tokenizer,
}
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, config: Vec<u8>) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let device = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
let config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let bert = BertModel::load(vb, &config)?;
Ok(Self { bert, tokenizer })
}
pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let input: Params =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
let sentences = input.sentences;
let normalize_embeddings = input.normalize_embeddings;
let device = &Device::Cpu;
if let Some(pp) = self.tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
self.tokenizer.with_padding(Some(pp));
}
let tokens = self
.tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(|m| JsError::new(&m.to_string()))?;
let token_ids: Vec<Tensor> = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<Result<Vec<_>, _>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let token_type_ids = token_ids.zeros_like()?;
console_log!("running inference on batch {:?}", token_ids.shape());
let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
console_log!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if normalize_embeddings {
embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
} else {
embeddings
};
let embeddings_data = embeddings.to_vec2()?;
Ok(serde_wasm_bindgen::to_value(&Embeddings {
data: embeddings_data,
})?)
}
}
#[derive(serde::Serialize, serde::Deserialize)]
struct Embeddings {
data: Vec<Vec<f64>>,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Params {
sentences: Vec<String>,
normalize_embeddings: bool,
}
fn main() {
console_error_panic_hook::set_once();
}

View File

@ -0,0 +1,20 @@
use candle_transformers::models::bert;
use wasm_bindgen::prelude::*;
pub use bert::{BertModel, Config, DTYPE};
pub use tokenizers::{PaddingParams, Tokenizer};
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}

View File

@ -0,0 +1,99 @@
export async function getEmbeddings(
worker,
weightsURL,
tokenizerURL,
configURL,
modelID,
sentences,
updateStatus = null
) {
return new Promise((resolve, reject) => {
worker.postMessage({
weightsURL,
tokenizerURL,
configURL,
modelID,
sentences,
});
function messageHandler(event) {
if ("error" in event.data) {
worker.removeEventListener("message", messageHandler);
reject(new Error(event.data.error));
}
if (event.data.status === "complete") {
worker.removeEventListener("message", messageHandler);
resolve(event.data);
}
if (updateStatus) updateStatus(event.data);
}
worker.addEventListener("message", messageHandler);
});
}
const MODELS = {
intfloat_e5_small_v2: {
base_url: "https://huggingface.co/intfloat/e5-small-v2/resolve/main/",
search_prefix: "query: ",
document_prefix: "passage: ",
},
intfloat_e5_base_v2: {
base_url: "https://huggingface.co/intfloat/e5-base-v2/resolve/main/",
search_prefix: "query: ",
document_prefix: "passage:",
},
intfloat_multilingual_e5_small: {
base_url:
"https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/",
search_prefix: "query: ",
document_prefix: "passage: ",
},
sentence_transformers_all_MiniLM_L6_v2: {
base_url:
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/refs%2Fpr%2F21/",
search_prefix: "",
document_prefix: "",
},
sentence_transformers_all_MiniLM_L12_v2: {
base_url:
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/refs%2Fpr%2F4/",
search_prefix: "",
document_prefix: "",
},
};
export function getModelInfo(id) {
return {
modelURL: MODELS[id].base_url + "model.safetensors",
configURL: MODELS[id].base_url + "config.json",
tokenizerURL: MODELS[id].base_url + "tokenizer.json",
search_prefix: MODELS[id].search_prefix,
document_prefix: MODELS[id].document_prefix,
};
}
export function cosineSimilarity(vec1, vec2) {
const dot = vec1.reduce((acc, val, i) => acc + val * vec2[i], 0);
const a = Math.sqrt(vec1.reduce((acc, val) => acc + val * val, 0));
const b = Math.sqrt(vec2.reduce((acc, val) => acc + val * val, 0));
return dot / (a * b);
}
export async function getWikiText(article) {
// thanks to wikipedia for the API
const URL = `https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exlimit=1&titles=${article}&explaintext=1&exsectionformat=plain&format=json&origin=*`;
return fetch(URL, {
method: "GET",
headers: {
Accept: "application/json",
},
})
.then((r) => r.json())
.then((data) => {
const pages = data.query.pages;
const pageId = Object.keys(pages)[0];
const extract = pages[pageId].extract;
if (extract === undefined || extract === "") {
throw new Error("No article found");
}
return extract;
})
.catch((error) => console.error("Error:", error));
}