mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
BERT Wasm (#902)
* implement wasm module * add example to workspace * add UI explore semantic similarity * change status messages * formatting * minor changes
This commit is contained in:
33
candle-wasm-examples/bert/Cargo.toml
Normal file
33
candle-wasm-examples/bert/Cargo.toml
Normal file
@ -0,0 +1,33 @@
|
||||
[package]
|
||||
name = "candle-wasm-example-bert"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" }
|
||||
candle-nn = { path = "../../candle-nn", version = "0.2.2" }
|
||||
candle-transformers = { path = "../../candle-transformers", version = "0.2.2" }
|
||||
num-traits = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||
|
||||
# App crates.
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
log = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
js-sys = "0.3.64"
|
||||
wasm-bindgen = "0.2.87"
|
||||
serde-wasm-bindgen = "0.6.0"
|
26
candle-wasm-examples/bert/README.md
Normal file
26
candle-wasm-examples/bert/README.md
Normal file
@ -0,0 +1,26 @@
|
||||
## Running BERT with Candle and WASM
|
||||
|
||||
Here, we provide two examples of how to run Bert using a Candle-compiled WASM binary and runtime.
|
||||
|
||||
### Vanilla JS and WebWorkers
|
||||
|
||||
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
|
||||
|
||||
```bash
|
||||
sh build-lib.sh
|
||||
```
|
||||
|
||||
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
|
||||
|
||||
```js
|
||||
import init, { Model } from "./build/m.js";
|
||||
```
|
||||
|
||||
The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
|
||||
Finally, you can preview the example by running a local HTTP server. For example:
|
||||
|
||||
```bash
|
||||
python -m http.server
|
||||
```
|
||||
|
||||
Then open `http://localhost:8000/lib-example.html` in your browser.
|
77
candle-wasm-examples/bert/bertWorker.js
Normal file
77
candle-wasm-examples/bert/bertWorker.js
Normal file
@ -0,0 +1,77 @@
|
||||
//load Candle Bert Module wasm module
|
||||
import init, { Model } from "./build/m.js";
|
||||
|
||||
// Fetches `url` as a Uint8Array, memoized in the Cache Storage API so model
// assets are only downloaded once per origin.
async function fetchArrayBuffer(url) {
  const cacheName = "bert-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }
  const res = await fetch(url, { cache: "force-cache" });
  // Do not poison the cache with error responses (404/5xx would otherwise be
  // served forever from cache on every subsequent load).
  if (!res.ok) {
    throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
  }
  // Await the put: the original fire-and-forget call could race with worker
  // teardown and silently fail to persist the entry.
  await cache.put(url, res.clone());
  return new Uint8Array(await res.arrayBuffer());
}
|
||||
// Lazily-constructed singleton per modelID: downloads the assets once and
// keeps the wasm Model instance alive for subsequent requests.
class Bert {
  // Cache of loaded models, keyed by modelID.
  static instance = {};

  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({ status: "loading", message: "Loading Model" });
      // Fetch weights, tokenizer and config in parallel.
      // (Renamed from `mel_filtersArrayU8` — a copy-paste from the whisper
      // example; this buffer actually holds the BERT JSON config bytes.)
      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
        await Promise.all([
          fetchArrayBuffer(weightsURL),
          fetchArrayBuffer(tokenizerURL),
          fetchArrayBuffer(configURL),
        ]);

      this.instance[modelID] = new Model(
        weightsArrayU8,
        tokenizerArrayU8,
        configArrayU8
      );
    } else {
      self.postMessage({ status: "ready", message: "Model Already Loaded" });
    }
    return this.instance[modelID];
  }
}
|
||||
|
||||
// Worker entry point: receives a model description plus sentences, loads the
// model (cached after the first call) and posts the embeddings back.
self.addEventListener("message", async (event) => {
  const {
    weightsURL,
    tokenizerURL,
    configURL,
    modelID,
    sentences,
    normalize = true,
  } = event.data;
  try {
    self.postMessage({ status: "ready", message: "Starting Bert Model" });
    const model = await Bert.getInstance(
      weightsURL,
      tokenizerURL,
      configURL,
      modelID
    );
    self.postMessage({ status: "embedding", message: "Calculating Embeddings" });
    const output = model.get_embeddings({
      sentences,
      normalize_embeddings: normalize,
    });
    self.postMessage({ status: "complete", message: "complete", output: output.data });
  } catch (e) {
    // Forward the failure to the page; utils.getEmbeddings rejects on `error`.
    self.postMessage({ error: e });
  }
});
|
2
candle-wasm-examples/bert/build-lib.sh
Normal file
2
candle-wasm-examples/bert/build-lib.sh
Normal file
@ -0,0 +1,2 @@
|
||||
#!/bin/sh
# Build the wasm module and generate the JS bindings under ./build.
# Abort on the first failure so wasm-bindgen never runs on a stale/missing
# artifact from a broken cargo build.
set -e

cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
|
368
candle-wasm-examples/bert/lib-example.html
Normal file
368
candle-wasm-examples/bert/lib-example.html
Normal file
@ -0,0 +1,368 @@
|
||||
<html>
|
||||
<head>
|
||||
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
|
||||
<title>Candle Bert</title>
|
||||
</head>
|
||||
<body></body>
|
||||
</html>
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<style>
|
||||
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
|
||||
html,
|
||||
body {
|
||||
font-family: "Source Sans 3", sans-serif;
|
||||
}
|
||||
</style>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script type="module" src="./code.js"></script>
|
||||
<script type="module">
|
||||
import { hcl } from "https://cdn.skypack.dev/d3-color@3";
|
||||
import { interpolateReds } from "https://cdn.skypack.dev/d3-scale-chromatic@3";
|
||||
import { scaleLinear } from "https://cdn.skypack.dev/d3-scale@4";
|
||||
import {
|
||||
getModelInfo,
|
||||
getEmbeddings,
|
||||
getWikiText,
|
||||
cosineSimilarity,
|
||||
} from "./utils.js";
|
||||
|
||||
const bertWorker = new Worker("./bertWorker.js", {
|
||||
type: "module",
|
||||
});
|
||||
|
||||
const inputContainerEL = document.querySelector("#input-container");
|
||||
const textAreaEl = document.querySelector("#input-area");
|
||||
const outputAreaEl = document.querySelector("#output-area");
|
||||
const formEl = document.querySelector("#form");
|
||||
const searchInputEl = document.querySelector("#search-input");
|
||||
const formWikiEl = document.querySelector("#form-wiki");
|
||||
const searchWikiEl = document.querySelector("#search-wiki");
|
||||
const outputStatusEl = document.querySelector("#output-status");
|
||||
const modelSelectEl = document.querySelector("#model");
|
||||
|
||||
const sentencesRegex =
|
||||
/(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z]\.)(?<=\.|\?)\s/gm;
|
||||
|
||||
let sentenceEmbeddings = [];
|
||||
let currInputText = "";
|
||||
let isCalculating = false;
|
||||
|
||||
function toggleTextArea(state) {
|
||||
if (state) {
|
||||
textAreaEl.hidden = false;
|
||||
textAreaEl.focus();
|
||||
} else {
|
||||
textAreaEl.hidden = true;
|
||||
}
|
||||
}
|
||||
inputContainerEL.addEventListener("focus", (e) => {
|
||||
toggleTextArea(true);
|
||||
});
|
||||
textAreaEl.addEventListener("blur", (e) => {
|
||||
toggleTextArea(false);
|
||||
});
|
||||
textAreaEl.addEventListener("focusout", (e) => {
|
||||
toggleTextArea(false);
|
||||
if (currInputText === textAreaEl.value || isCalculating) return;
|
||||
populateOutputArea(textAreaEl.value);
|
||||
calculateEmbeddings(textAreaEl.value);
|
||||
});
|
||||
|
||||
modelSelectEl.addEventListener("change", (e) => {
|
||||
if (currInputText === "" || isCalculating) return;
|
||||
populateOutputArea(textAreaEl.value);
|
||||
calculateEmbeddings(textAreaEl.value);
|
||||
});
|
||||
|
||||
function populateOutputArea(text) {
|
||||
currInputText = text;
|
||||
const sentences = text.split(sentencesRegex);
|
||||
|
||||
outputAreaEl.innerHTML = "";
|
||||
for (const [id, sentence] of sentences.entries()) {
|
||||
const sentenceEl = document.createElement("span");
|
||||
sentenceEl.id = `sentence-${id}`;
|
||||
sentenceEl.innerText = sentence + " ";
|
||||
outputAreaEl.appendChild(sentenceEl);
|
||||
}
|
||||
}
|
||||
formEl.addEventListener("submit", async (e) => {
|
||||
e.preventDefault();
|
||||
if (isCalculating || currInputText === "") return;
|
||||
toggleInputs(true);
|
||||
const modelID = modelSelectEl.value;
|
||||
const { modelURL, tokenizerURL, configURL, search_prefix } =
|
||||
getModelInfo(modelID);
|
||||
|
||||
const text = searchInputEl.value;
|
||||
const query = search_prefix + searchInputEl.value;
|
||||
outputStatusEl.classList.remove("invisible");
|
||||
outputStatusEl.innerText = "Calculating embeddings for query...";
|
||||
isCalculating = true;
|
||||
const out = await getEmbeddings(
|
||||
bertWorker,
|
||||
modelURL,
|
||||
tokenizerURL,
|
||||
configURL,
|
||||
modelID,
|
||||
[query]
|
||||
);
|
||||
outputStatusEl.classList.add("invisible");
|
||||
const queryEmbeddings = out.output[0];
|
||||
// calculate cosine similarity with all sentences given the query
|
||||
const distances = sentenceEmbeddings
|
||||
.map((embedding, id) => ({
|
||||
id,
|
||||
similarity: cosineSimilarity(queryEmbeddings, embedding),
|
||||
}))
|
||||
.sort((a, b) => b.similarity - a.similarity)
|
||||
// getting top 10 most similar sentences
|
||||
.slice(0, 10);
|
||||
|
||||
const colorScale = scaleLinear()
|
||||
.domain([
|
||||
distances[distances.length - 1].similarity,
|
||||
distances[0].similarity,
|
||||
])
|
||||
.range([0, 1])
|
||||
.interpolate(() => interpolateReds);
|
||||
outputAreaEl.querySelectorAll("span").forEach((el) => {
|
||||
el.style.color = "unset";
|
||||
el.style.backgroundColor = "unset";
|
||||
});
|
||||
distances.forEach((d) => {
|
||||
const el = outputAreaEl.querySelector(`#sentence-${d.id}`);
|
||||
const color = colorScale(d.similarity);
|
||||
const fontColor = hcl(color).l < 70 ? "white" : "black";
|
||||
el.style.color = fontColor;
|
||||
el.style.backgroundColor = color;
|
||||
});
|
||||
|
||||
outputAreaEl
|
||||
.querySelector(`#sentence-${distances[0].id}`)
|
||||
.scrollIntoView({
|
||||
behavior: "smooth",
|
||||
block: "center",
|
||||
inline: "nearest",
|
||||
});
|
||||
|
||||
isCalculating = false;
|
||||
toggleInputs(false);
|
||||
});
|
||||
async function calculateEmbeddings(text) {
|
||||
isCalculating = true;
|
||||
toggleInputs(true);
|
||||
const modelID = modelSelectEl.value;
|
||||
const { modelURL, tokenizerURL, configURL, document_prefix } =
|
||||
getModelInfo(modelID);
|
||||
|
||||
const sentences = text.split(sentencesRegex);
|
||||
const allEmbeddings = [];
|
||||
outputStatusEl.classList.remove("invisible");
|
||||
for (const [id, sentence] of sentences.entries()) {
|
||||
const query = document_prefix + sentence;
|
||||
outputStatusEl.innerText = `Calculating embeddings: sentence ${
|
||||
id + 1
|
||||
} of ${sentences.length}`;
|
||||
const embeddings = await getEmbeddings(
|
||||
bertWorker,
|
||||
modelURL,
|
||||
tokenizerURL,
|
||||
configURL,
|
||||
modelID,
|
||||
[query],
|
||||
updateStatus
|
||||
);
|
||||
allEmbeddings.push(embeddings);
|
||||
}
|
||||
outputStatusEl.classList.add("invisible");
|
||||
sentenceEmbeddings = allEmbeddings.map((e) => e.output[0]);
|
||||
isCalculating = false;
|
||||
toggleInputs(false);
|
||||
}
|
||||
|
||||
function updateStatus(data) {
|
||||
if ("status" in data) {
|
||||
if (data.status === "loading") {
|
||||
outputStatusEl.innerText = data.message;
|
||||
outputStatusEl.classList.remove("invisible");
|
||||
}
|
||||
}
|
||||
}
|
||||
function toggleInputs(state) {
|
||||
const interactive = document.querySelectorAll(".interactive");
|
||||
interactive.forEach((el) => {
|
||||
if (state) {
|
||||
el.disabled = true;
|
||||
} else {
|
||||
el.disabled = false;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
searchWikiEl.addEventListener("input", () => {
|
||||
searchWikiEl.setCustomValidity("");
|
||||
});
|
||||
|
||||
formWikiEl.addEventListener("submit", async (e) => {
  e.preventDefault();
  // Example buttons carry a `data-example` attribute; use their label as
  // the article title.
  if ("example" in e.submitter.dataset) {
    searchWikiEl.value = e.submitter.innerText;
  }
  const text = searchWikiEl.value;

  if (isCalculating || text === "") return;
  try {
    const wikiText = await getWikiText(text);
    searchWikiEl.setCustomValidity("");
    // Use `value`, not `innerHTML`: the article text is untrusted external
    // input and must not be parsed as HTML, and `value` is what keeps the
    // textarea's live content in sync once the user edits it.
    textAreaEl.value = wikiText;
    populateOutputArea(wikiText);
    calculateEmbeddings(wikiText);
    searchWikiEl.value = "";
  } catch {
    searchWikiEl.setCustomValidity("Invalid Wikipedia article name");
    searchWikiEl.reportValidity();
  }
});
|
||||
</script>
|
||||
</head>
|
||||
<body class="container max-w-4xl mx-auto p-4">
|
||||
<main class="grid grid-cols-1 gap-5 relative">
|
||||
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
|
||||
<div>
|
||||
<h1 class="text-5xl font-bold">Candle BERT</h1>
|
||||
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
|
||||
<p class="max-w-lg">
|
||||
Running sentence embeddings and similarity search in the browser using
|
||||
the Bert Model written with
|
||||
<a
|
||||
href="https://github.com/huggingface/candle/"
|
||||
target="_blank"
|
||||
class="underline hover:text-blue-500 hover:no-underline"
|
||||
>Candle
|
||||
</a>
|
||||
and compiled to Wasm. Embeddings models are from
|
||||
<a
|
||||
href="https://huggingface.co/sentence-transformers/"
|
||||
target="_blank"
|
||||
class="underline hover:text-blue-500 hover:no-underline"
|
||||
>
|
||||
Sentence Transformers
|
||||
</a>
|
||||
and
|
||||
<a
|
||||
href="https://huggingface.co/intfloat/"
|
||||
target="_blank"
|
||||
class="underline hover:text-blue-500 hover:no-underline"
|
||||
>
|
||||
Liang Wang - e5 Models
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="model" class="font-medium block">Models Options: </label>
|
||||
<select
|
||||
id="model"
|
||||
class="border-2 border-gray-500 rounded-md font-light interactive disabled:cursor-not-allowed w-full max-w-max"
|
||||
>
|
||||
<option value="intfloat_e5_small_v2" selected>
|
||||
intfloat/e5-small-v2 (133 MB)
|
||||
</option>
|
||||
<option value="intfloat_e5_base_v2">
|
||||
intfloat/e5-base-v2 (438 MB)
|
||||
</option>
|
||||
<option value="intfloat_multilingual_e5_small">
|
||||
intfloat/multilingual-e5-small (471 MB)
|
||||
</option>
|
||||
<option value="sentence_transformers_all_MiniLM_L6_v2">
|
||||
sentence-transformers/all-MiniLM-L6-v2 (90.9 MB)
|
||||
</option>
|
||||
<option value="sentence_transformers_all_MiniLM_L12_v2">
|
||||
sentence-transformers/all-MiniLM-L12-v2 (133 MB)
|
||||
</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<h3 class="font-medium">Examples:</h3>
|
||||
<form
|
||||
id="form-wiki"
|
||||
class="flex text-xs rounded-md justify-between w-min gap-3"
|
||||
>
|
||||
<input type="submit" hidden />
|
||||
|
||||
<button data-example class="disabled:cursor-not-allowed interactive">
|
||||
Pizza
|
||||
</button>
|
||||
<button data-example class="disabled:cursor-not-allowed interactive">
|
||||
Paris
|
||||
</button>
|
||||
<button data-example class="disabled:cursor-not-allowed interactive">
|
||||
Physics
|
||||
</button>
|
||||
<input
|
||||
type="text"
|
||||
id="search-wiki"
|
||||
title="Search Wikipedia article by title"
|
||||
class="font-light py-0 mx-1 resize-none outline-none w-32 disabled:cursor-not-allowed interactive"
|
||||
placeholder="Load Wikipedia article..."
|
||||
/>
|
||||
<button
|
||||
title="Search Wikipedia article and load into input"
|
||||
class="bg-gray-700 hover:bg-gray-800 text-white font-normal px-2 py-1 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
|
||||
>
|
||||
Load
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
<form
|
||||
id="form"
|
||||
class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
|
||||
>
|
||||
<input type="submit" hidden />
|
||||
<input
|
||||
type="text"
|
||||
id="search-input"
|
||||
class="font-light w-full px-3 py-2 mx-1 resize-none outline-none interactive disabled:cursor-not-allowed"
|
||||
placeholder="Search query here..."
|
||||
/>
|
||||
<button
|
||||
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
|
||||
>
|
||||
Search
|
||||
</button>
|
||||
</form>
|
||||
<div>
|
||||
<h3 class="font-medium">Input text:</h3>
|
||||
<div class="flex justify-between items-center">
|
||||
<div class="rounded-md inline text-xs">
|
||||
<span id="output-status" class="m-auto font-light invisible"
|
||||
>C</span
|
||||
>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
id="input-container"
|
||||
tabindex="0"
|
||||
class="min-h-[250px] bg-slate-100 text-gray-500 rounded-md p-4 flex flex-col gap-2 relative"
|
||||
>
|
||||
<textarea
|
||||
id="input-area"
|
||||
hidden
|
||||
value=""
|
||||
placeholder="Input text to perform semantic similarity search..."
|
||||
class="flex-1 resize-none outline-none left-0 right-0 top-0 bottom-0 m-4 absolute interactive disabled:invisible"
|
||||
></textarea>
|
||||
<p id="output-area" class="grid-rows-2">
|
||||
Input text to perform semantic similarity search...
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
92
candle-wasm-examples/bert/src/bin/m.rs
Normal file
92
candle-wasm-examples/bert/src/bin/m.rs
Normal file
@ -0,0 +1,92 @@
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::bert::{BertModel, Config};
|
||||
use candle_wasm_example_bert::console_log;
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// BERT model plus its tokenizer, exposed to JavaScript via wasm-bindgen.
#[wasm_bindgen]
pub struct Model {
    bert: BertModel,
    tokenizer: Tokenizer,
}

#[wasm_bindgen]
impl Model {
    /// Builds a `Model` from raw byte buffers supplied by the JS side:
    /// safetensors weights, a tokenizer definition, and the BERT JSON config.
    #[wasm_bindgen(constructor)]
    pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, config: Vec<u8>) -> Result<Model, JsError> {
        console_error_panic_hook::set_once();
        console_log!("loading model");
        let device = &Device::Cpu;
        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
        // NOTE(review): weights are loaded as F64, matching the f64 embeddings
        // returned below — confirm this is intentional (BERT checkpoints are
        // typically f32).
        let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
        let config: Config = serde_json::from_slice(&config)?;
        let tokenizer =
            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
        let bert = BertModel::load(vb, &config)?;

        Ok(Self { bert, tokenizer })
    }

    /// Computes sentence embeddings for the given `Params` (passed as a JS
    /// object) and returns an `Embeddings` value, optionally L2-normalized.
    pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> {
        let input: Params =
            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
        let sentences = input.sentences;
        let normalize_embeddings = input.normalize_embeddings;

        let device = &Device::Cpu;
        // Pad every sentence in the batch to the longest one so they can be
        // stacked into a single rectangular tensor below.
        if let Some(pp) = self.tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            self.tokenizer.with_padding(Some(pp));
        }
        let tokens = self
            .tokenizer
            .encode_batch(sentences.to_vec(), true)
            .map_err(|m| JsError::new(&m.to_string()))?;

        // One 1-D tensor of token ids per sentence.
        let token_ids: Vec<Tensor> = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<Result<Vec<_>, _>>()?;

        let token_ids = Tensor::stack(&token_ids, 0)?;
        // Single-segment input: all token type ids are zero.
        let token_type_ids = token_ids.zeros_like()?;
        console_log!("running inference on batch {:?}", token_ids.shape());
        let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
        console_log!("generated embeddings {:?}", embeddings.shape());
        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        // Optionally divide each row by its L2 norm.
        let embeddings = if normalize_embeddings {
            embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
        } else {
            embeddings
        };
        let embeddings_data = embeddings.to_vec2()?;
        Ok(serde_wasm_bindgen::to_value(&Embeddings {
            data: embeddings_data,
        })?)
    }
}

/// Embedding matrix returned to JS: one row of f64 values per input sentence.
#[derive(serde::Serialize, serde::Deserialize)]
struct Embeddings {
    data: Vec<Vec<f64>>,
}

/// Arguments accepted by `get_embeddings`, deserialized from a JS object.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Params {
    sentences: Vec<String>,
    normalize_embeddings: bool,
}

fn main() {
    console_error_panic_hook::set_once();
}
|
20
candle-wasm-examples/bert/src/lib.rs
Normal file
20
candle-wasm-examples/bert/src/lib.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use candle_transformers::models::bert;
use wasm_bindgen::prelude::*;

// Re-exported so the binary crate (src/bin/m.rs) can use them directly.
pub use bert::{BertModel, Config, DTYPE};
pub use tokenizers::{PaddingParams, Tokenizer};

#[wasm_bindgen]
extern "C" {
    // Use `js_namespace` here to bind `console.log(..)` instead of just
    // `log(..)`
    #[wasm_bindgen(js_namespace = console)]
    pub fn log(s: &str);
}

/// `println!`-style logging to the browser console via the imported
/// `console.log` binding above.
#[macro_export]
macro_rules! console_log {
    // Note that this is using the `log` function imported above during
    // `bare_bones`
    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}
|
99
candle-wasm-examples/bert/utils.js
Normal file
99
candle-wasm-examples/bert/utils.js
Normal file
@ -0,0 +1,99 @@
|
||||
/**
 * Asks the BERT worker to embed `sentences` and resolves with the worker's
 * "complete" payload. Adapts the worker's message-based API to a Promise;
 * `updateStatus`, when given, receives every intermediate status message.
 */
export async function getEmbeddings(
  worker,
  weightsURL,
  tokenizerURL,
  configURL,
  modelID,
  sentences,
  updateStatus = null
) {
  return new Promise((resolve, reject) => {
    const onMessage = ({ data }) => {
      if ("error" in data) {
        worker.removeEventListener("message", onMessage);
        reject(new Error(data.error));
      }
      if (data.status === "complete") {
        worker.removeEventListener("message", onMessage);
        resolve(data);
      }
      if (updateStatus) updateStatus(data);
    };
    worker.addEventListener("message", onMessage);
    worker.postMessage({
      weightsURL,
      tokenizerURL,
      configURL,
      modelID,
      sentences,
    });
  });
}
|
||||
|
||||
// Catalog of supported embedding models. e5 models require the documented
// "query: " / "passage: " prefixes to perform well; MiniLM models use none.
const MODELS = {
  intfloat_e5_small_v2: {
    base_url: "https://huggingface.co/intfloat/e5-small-v2/resolve/main/",
    search_prefix: "query: ",
    document_prefix: "passage: ",
  },
  intfloat_e5_base_v2: {
    base_url: "https://huggingface.co/intfloat/e5-base-v2/resolve/main/",
    search_prefix: "query: ",
    // Fixed: was "passage:" (missing trailing space), inconsistent with the
    // other e5 entries and with the prefix the e5 models were trained on.
    document_prefix: "passage: ",
  },
  intfloat_multilingual_e5_small: {
    base_url:
      "https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/",
    search_prefix: "query: ",
    document_prefix: "passage: ",
  },
  sentence_transformers_all_MiniLM_L6_v2: {
    base_url:
      "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/refs%2Fpr%2F21/",
    search_prefix: "",
    document_prefix: "",
  },
  sentence_transformers_all_MiniLM_L12_v2: {
    base_url:
      "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/refs%2Fpr%2F4/",
    search_prefix: "",
    document_prefix: "",
  },
};

/**
 * Resolves a model id from the catalog above into the concrete asset URLs
 * and the query/document prefixes the UI must prepend before embedding.
 */
export function getModelInfo(id) {
  return {
    modelURL: MODELS[id].base_url + "model.safetensors",
    configURL: MODELS[id].base_url + "config.json",
    tokenizerURL: MODELS[id].base_url + "tokenizer.json",
    search_prefix: MODELS[id].search_prefix,
    document_prefix: MODELS[id].document_prefix,
  };
}
|
||||
|
||||
/**
 * Cosine similarity of two equal-length numeric vectors:
 * dot(v1, v2) / (|v1| * |v2|).
 * Returns 0 when either vector has zero magnitude (the original divided by
 * zero and yielded NaN, which poisons the similarity sort downstream).
 */
export function cosineSimilarity(vec1, vec2) {
  const dot = vec1.reduce((acc, val, i) => acc + val * vec2[i], 0);
  const a = Math.sqrt(vec1.reduce((acc, val) => acc + val * val, 0));
  const b = Math.sqrt(vec2.reduce((acc, val) => acc + val * val, 0));
  if (a === 0 || b === 0) return 0;
  return dot / (a * b);
}
|
||||
/**
 * Fetches the plain-text extract of a Wikipedia article by title.
 * Throws when the article does not exist or the request fails — the caller
 * relies on the rejection to show its "Invalid Wikipedia article name" UI.
 * (The original trailing `.catch` swallowed every error and resolved
 * `undefined`, so that error path could never trigger.)
 */
export async function getWikiText(article) {
  // thanks to wikipedia for the API; `origin=*` enables CORS from any page.
  // Encode the title so spaces/`&`/`?` in article names don't break the query.
  const URL = `https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exlimit=1&titles=${encodeURIComponent(
    article
  )}&explaintext=1&exsectionformat=plain&format=json&origin=*`;
  const response = await fetch(URL, {
    method: "GET",
    headers: {
      Accept: "application/json",
    },
  });
  if (!response.ok) {
    throw new Error(`Wikipedia request failed: ${response.status}`);
  }
  const data = await response.json();
  const pages = data.query.pages;
  const pageId = Object.keys(pages)[0];
  const extract = pages[pageId].extract;
  if (extract === undefined || extract === "") {
    throw new Error("No article found");
  }
  return extract;
}
|
Reference in New Issue
Block a user