Moondream WASM (#1999)

* moondream wasm wip

* examples, more

* fix eos token check

* README

* cleanip

* cleanup, clippy
This commit is contained in:
Radamés Ajna
2024-04-02 22:11:50 -07:00
committed by GitHub
parent cd6b9e317c
commit 26226068a4
8 changed files with 1128 additions and 0 deletions

View File

@ -0,0 +1,32 @@
[package]
name = "candle-wasm-example-moondream"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }
# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
image = { workspace = true }
log = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
wasm-bindgen = "0.2.87"
js-sys = "0.3.64"
serde-wasm-bindgen = "0.6.5"

View File

@ -0,0 +1,24 @@
## Running [Moondream 2](https://huggingface.co/vikhyatk/moondream2) Model Example
### Vanilla JS and WebWorkers
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
```bash
sh build-lib.sh
```
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
```js
import init, { Model } from "./build/m.js";
```
The full example can be found under `./index.html`. All needed assets are fetched from the web, so no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:
```bash
python -m http.server
```
Then open `http://localhost:8000/index.html` in your browser.

View File

@ -0,0 +1,2 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web

View File

@ -0,0 +1,262 @@
import snarkdown from "https://cdn.skypack.dev/snarkdown";
import hljs from "https://cdn.skypack.dev/highlight.js";
// models base url
const MODELS = {
moondream2_q4k: {
base_url:
"https://huggingface.co/santiagomed/candle-moondream/resolve/main/",
model: "model-q4_0.gguf",
tokenizer: "tokenizer.json",
quantized: true,
size: "1.51 GB",
},
};
const moodreamWorker = new Worker("./moondreamWorker.js", {
type: "module",
});
async function generateSequence(controller) {
const getValue = (id) => document.querySelector(`#${id}`).value;
const modelID = getValue("model");
const model = MODELS[modelID];
const weightsURL =
model.model instanceof Array
? model.model.map((m) => model.base_url + m)
: model.base_url + model.model;
const tokenizerURL = model.base_url + model.tokenizer;
const prompt = getValue("prompt").trim();
const temperature = getValue("temperature");
const topP = getValue("top-p");
const repeatPenalty = getValue("repeat_penalty");
const seed = getValue("seed");
const maxSeqLen = getValue("max-seq");
if (prompt?.value?.trim() === "") {
return;
}
function updateStatus(data) {
const outStatus = document.querySelector("#output-status");
const outGen = document.querySelector("#output-generation");
const outCounter = document.querySelector("#output-counter");
switch (data.status) {
case "loading":
outStatus.hidden = false;
outStatus.textContent = data.message;
outGen.hidden = true;
outCounter.hidden = true;
break;
case "generating":
const { message, prompt, sentence, tokensSec, totalTime } = data;
outStatus.hidden = true;
outCounter.hidden = false;
outGen.hidden = false;
outGen.innerHTML = snarkdown(prompt + sentence);
outCounter.innerHTML = `${(totalTime / 1000).toFixed(
2
)}s (${tokensSec.toFixed(2)} tok/s)`;
hljs.highlightAll();
break;
case "complete":
outStatus.hidden = true;
outGen.hidden = false;
break;
}
}
return new Promise((resolve, reject) => {
moodreamWorker.postMessage({
weightsURL,
modelID,
tokenizerURL,
quantized: model.quantized,
imageURL: currentImageURL,
prompt,
temp: temperature,
top_p: topP,
repeatPenalty,
seed: seed,
maxSeqLen,
verbose_prompt: false,
command: "start",
});
const handleAbort = () => {
moodreamWorker.postMessage({ command: "abort" });
};
const handleMessage = (event) => {
const { status, error, message, prompt, sentence } = event.data;
if (status) updateStatus(event.data);
if (error) {
moodreamWorker.removeEventListener("message", handleMessage);
reject(new Error(error));
}
if (status === "aborted") {
moodreamWorker.removeEventListener("message", handleMessage);
resolve(event.data);
}
if (status === "complete") {
moodreamWorker.removeEventListener("message", handleMessage);
resolve(event.data);
}
};
controller.signal.addEventListener("abort", handleAbort);
moodreamWorker.addEventListener("message", handleMessage);
});
}
const form = document.querySelector("#form");
const prompt = document.querySelector("#prompt");
const runBtn = document.querySelector("#run");
const modelSelect = document.querySelector("#model");
const dropArea = document.querySelector("#drop-area");
const canvas = document.querySelector("#canvas");
const ctxCanvas = canvas.getContext("2d");
const fileUpload = document.querySelector("#file-upload");
const clearImgBtn = document.querySelector("#clear-img-btn");
const imagesExamples = document.querySelector("#image-select");
let currentImageURL = null;
let runController = new AbortController();
let isRunning = false;
document.addEventListener("DOMContentLoaded", () => {
for (const [id, model] of Object.entries(MODELS)) {
const option = document.createElement("option");
option.value = id;
option.innerText = `${id} (${model.size})`;
modelSelect.appendChild(option);
}
const query = new URLSearchParams(window.location.search);
const modelID = query.get("model");
if (modelID) {
modelSelect.value = modelID;
} else {
modelSelect.value = "moondream2_q4k";
}
});
imagesExamples.addEventListener("click", (e) => {
// if (isEmbedding || isSegmenting) {
// return;
// }
const target = e.target;
if (target.nodeName === "IMG") {
const href = target.src;
clearImageCanvas();
currentImageURL = href;
drawImageCanvas(href);
}
});
modelSelect.addEventListener("change", (e) => {
const query = new URLSearchParams(window.location.search);
query.set("model", e.target.value);
window.history.replaceState({}, "", `${window.location.pathname}?${query}`);
window.parent.postMessage({ queryString: "?" + query }, "*");
const model = MODELS[e.target.value];
document.querySelector("#max-seq").max = model.seq_len;
document.querySelector("#max-seq").nextElementSibling.value = 200;
});
clearImgBtn.addEventListener("click", () => {
clearImageCanvas();
});
//add event listener to file input
fileUpload.addEventListener("input", async (e) => {
const target = e.target;
if (target.files.length > 0 && !target.files[0].type.includes("svg")) {
const href = URL.createObjectURL(target.files[0]);
clearImageCanvas();
await drawImageCanvas(href);
}
});
// add event listener to drop-area
dropArea.addEventListener("dragenter", (e) => {
e.preventDefault();
dropArea.classList.add("border-blue-700");
});
dropArea.addEventListener("dragleave", (e) => {
e.preventDefault();
dropArea.classList.remove("border-blue-700");
});
dropArea.addEventListener("dragover", (e) => {
e.preventDefault();
});
dropArea.addEventListener("drop", async (e) => {
e.preventDefault();
dropArea.classList.remove("border-blue-700");
const url = e.dataTransfer.getData("text/uri-list");
const files = e.dataTransfer.files;
if (files.length > 0) {
const href = URL.createObjectURL(files[0]);
clearImageCanvas();
await drawImageCanvas(href);
} else if (url) {
clearImageCanvas();
await drawImageCanvas(url);
}
});
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (isRunning) {
stopRunning();
} else {
startRunning();
await generateSequence(runController);
stopRunning();
}
});
async function drawImageCanvas(imgURL) {
if (!imgURL) {
throw new Error("No image URL provided");
}
return new Promise((resolve, reject) => {
ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
const img = new Image();
img.crossOrigin = "anonymous";
img.onload = () => {
canvas.width = img.width;
canvas.height = img.height;
ctxCanvas.drawImage(img, 0, 0);
clearImgBtn.disabled = false;
resolve(img);
};
img.src = imgURL;
currentImageURL = imgURL;
});
}
function clearImageCanvas() {
ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
clearImgBtn.disabled = true;
canvas.parentElement.style.height = "auto";
currentImageURL = null;
canvas.width = 0;
canvas.height = 0;
}
function startRunning() {
isRunning = true;
runBtn.textContent = "Stop";
prompt.disabled = true;
}
function stopRunning() {
runController.abort();
runController = new AbortController();
runBtn.textContent = "Run";
isRunning = false;
prompt.disabled = false;
}
prompt.addEventListener("input", (e) => {
runBtn.disabled = false;
});

View File

@ -0,0 +1,312 @@
<html>
<head>
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
<title>Candle Moondream Rust/WASM</title>
</head>
<body></body>
</html>
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/styles/default.min.css"
/>
<style>
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
html,
body {
font-family: "Source Sans 3", sans-serif;
}
code,
output,
select,
pre {
font-family: "Source Code Pro", monospace;
}
</style>
<style type="text/tailwindcss">
.link {
@apply underline hover:text-blue-500 hover:no-underline;
}
</style>
<script src="https://cdn.tailwindcss.com/3.4.3"></script>
<script type="module" src="./code.js"></script>
</head>
<body class="container max-w-4xl mx-auto p-4 text-gray-800">
<main class="grid grid-cols-1 gap-8 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
<h1 class="text-5xl font-bold">Candle Moondream 2</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
<a
href="https://huggingface.co/vikhyatk/moondream2"
class="link"
target="_blank"
>Moondream 2</a
>
by
<a
href=" https://huggingface.co/vikhyatk"
class="link"
target="_blank"
>Vik</a
>
and model implementation on Candle by
<a
href="https://huggingface.co/santiagomed"
class="link"
target="_blank"
>Santiago Medina
</a>
</p>
</div>
<div>
<p class="text-xs italic max-w-lg">
<b>Note:</b>
When first run, the app will download and cache the model, which could
take a few minutes. Then, the embeddings and generation will take a
few minutes to start 😔.
</p>
</div>
<div>
<label for="model" class="font-medium">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light"
></select>
</div>
<form
id="form"
class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
>
<input type="submit" hidden />
<input
type="text"
id="prompt"
class="font-light text-lg w-full px-3 py-2 mx-1 resize-none outline-none"
placeholder="Add your prompt here..."
/>
<button
id="run"
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
>
Run
</button>
</form>
<details>
<summary class="font-medium cursor-pointer">Advanced Options</summary>
<div class="grid grid-cols-3 max-w-md items-center gap-3 py-3">
<label class="text-sm font-medium" for="max-seq"
>Maximum length
</label>
<input
type="range"
id="max-seq"
name="max-seq"
min="1"
max="2048"
step="1"
value="500"
oninput="this.nextElementSibling.value = Number(this.value)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>
500</output
>
<label class="text-sm font-medium" for="temperature"
>Temperature</label
>
<input
type="range"
id="temperature"
name="temperature"
min="0"
max="2"
step="0.01"
value="0.00"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>
0.00</output
>
<label class="text-sm font-medium" for="top-p">Top-p</label>
<input
type="range"
id="top-p"
name="top-p"
min="0"
max="1"
step="0.01"
value="1.00"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>
1.00</output
>
<label class="text-sm font-medium" for="repeat_penalty"
>Repeat Penalty</label
>
<input
type="range"
id="repeat_penalty"
name="repeat_penalty"
min="1"
max="2"
step="0.01"
value="1.10"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>1.10</output
>
<label class="text-sm font-medium" for="seed">Seed</label>
<input
type="number"
id="seed"
name="seed"
value="299792458"
class="font-light border border-gray-700 text-right rounded-md p-2"
/>
<button
id="run"
onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm"
>
Rand
</button>
</div>
</details>
<div class="grid md:grid-cols-2 gap-4 items-start">
<div>
<div class="relative md:mt-6">
<div
class="absolute w-full bottom-full flex justify-between items-center"
>
<div class="flex gap-2 w-full">
<button
id="clear-img-btn"
disabled
title="Clear Image"
class="ml-auto text-xs py-1 bg-white rounded-md disabled:opacity-20 flex gap-1 items-center"
>
<svg
class=""
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 13 12"
height="1em"
>
<path
d="M1.6.7 12 11.1M12 .7 1.6 11.1"
stroke="#2E3036"
stroke-width="2"
/>
</svg>
</button>
</div>
</div>
<div
id="drop-area"
class="min-h-[250px] flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative w-full overflow-hidden"
>
<div
class="absolute flex flex-col items-center justify-center space-y-1 text-center"
>
<svg
width="25"
height="25"
viewBox="0 0 25 25"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
<path
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
fill="#000"
/>
</svg>
<div class="flex text-sm text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
>
<span>Drag and drop the image here</span>
<span class="block text-xs">or</span>
<span class="block text-xs">Click to upload</span>
</label>
</div>
<input
id="file-upload"
name="file-upload"
type="file"
accept="image/*"
class="sr-only"
/>
</div>
<canvas
id="canvas"
class="z-10 pointer-events-none w-full"
></canvas>
</div>
</div>
</div>
<div>
<h3 class="font-medium">Generation:</h3>
<div
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
>
<div
id="output-counter"
hidden
class="ml-auto font-semibold grid-rows-1"
></div>
<p hidden id="output-generation" class="grid-rows-2 text-lg"></p>
<span id="output-status" class="m-auto font-light"
>No output yet</span
>
</div>
</div>
</div>
<div>
<div
class="flex gap-3 items-center overflow-x-scroll"
id="image-select"
>
<h3 class="font-medium">Examples:</h3>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
class="cursor-pointer w-24 h-24 object-cover"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
class="cursor-pointer w-24 h-24 object-cover"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
class="cursor-pointer w-24 h-24 object-cover"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/demo-1.jpg"
class="cursor-pointer w-24 h-24 object-cover"
/>
</div>
</div>
</main>
</body>
</html>

View File

@ -0,0 +1,201 @@
import init, { Model } from "./build/m.js";
async function fetchArrayBuffer(url, cacheModel = true) {
if (!cacheModel)
return new Uint8Array(await (await fetch(url)).arrayBuffer());
const cacheName = "moondream-candle-cache";
const cache = await caches.open(cacheName);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const data = await cachedResponse.arrayBuffer();
return new Uint8Array(data);
}
const res = await fetch(url, { cache: "force-cache" });
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
async function concatenateArrayBuffers(urls) {
const arrayBuffers = await Promise.all(
urls.map((url) => fetchArrayBuffer(url))
);
let totalLength = arrayBuffers.reduce(
(acc, arrayBuffer) => acc + arrayBuffer.byteLength,
0
);
let concatenatedBuffer = new Uint8Array(totalLength);
let offset = 0;
arrayBuffers.forEach((buffer) => {
concatenatedBuffer.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
});
return concatenatedBuffer;
}
class Moondream {
static imageArrayHash = {};
static instance = {};
static currentModelID = null;
static async getInstance(weightsURL, modelID, tokenizerURL, quantized) {
// load individual modelID only once
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
weightsURL instanceof Array
? concatenateArrayBuffers(weightsURL)
: fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
]);
this.instance[modelID] = new Model(
weightsArrayU8,
tokenizerArrayU8,
quantized
);
}
this.currentModelID = modelID;
return this.instance[modelID];
}
// Remove the modelID parameter from setImageEmbeddings
static setImageEmbeddings(imageArrayU8) {
// check if image embeddings are already set for this image and model
const imageArrayHash = this.getSimpleHash(imageArrayU8);
if (
this.imageArrayHash[this.currentModelID] === imageArrayHash &&
this.instance[this.currentModelID]
) {
self.postMessage({
status: "embedding",
message: "Embeddings Already Set",
});
return;
}
this.imageArrayHash[this.currentModelID] = imageArrayHash;
this.instance[this.currentModelID].set_image_embeddings(imageArrayU8);
self.postMessage({ status: "embedding", message: "Embeddings Set" });
}
static getSimpleHash(imageArrayU8) {
// get simple hash of imageArrayU8
let imageArrayHash = 0;
for (let i = 0; i < imageArrayU8.length; i += 100) {
imageArrayHash ^= imageArrayU8[i];
}
return imageArrayHash.toString(16);
}
}
let controller = null;
self.addEventListener("message", (event) => {
if (event.data.command === "start") {
controller = new AbortController();
generate(event.data);
} else if (event.data.command === "abort") {
controller.abort();
}
});
async function generate(data) {
const {
weightsURL,
modelID,
tokenizerURL,
quantized,
imageURL,
prompt,
seed,
temp,
top_p,
repeatPenalty,
maxSeqLen,
verbose_prompt,
} = data;
try {
self.postMessage({ status: "loading", message: "Starting Moondream" });
const model = await Moondream.getInstance(
weightsURL,
modelID,
tokenizerURL,
quantized
);
self.postMessage({ status: "loading", message: "Initializing model" });
self.postMessage({ status: "loading", message: "Loading Image" });
const imageArrayU8 = await fetchArrayBuffer(imageURL, false);
self.postMessage({ status: "embedding", message: "Creating Embeddings" });
Moondream.setImageEmbeddings(imageArrayU8);
self.postMessage({
status: "complete-embedding",
message: "Embeddings Complete",
});
const { token, token_id } = model.init_with_image_prompt({
prompt,
seed: BigInt(seed),
temp: parseFloat(temp),
top_p: parseFloat(top_p),
repeat_penalty: parseFloat(repeatPenalty),
repeat_last_n: 64,
verbose_prompt,
});
const seq_len = 2048;
let sentence = token;
let maxTokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
let startTime = performance.now();
let tokensCount = 0;
while (tokensCount < maxTokens) {
await new Promise(async (resolve) => {
if (controller && controller.signal.aborted) {
console.log("Aborted");
self.postMessage({
status: "aborted",
message: "Aborted",
output: prompt + sentence,
});
return;
}
const { token, token_id } = await model.next_token();
if (token_id === 50256) {
// <|endoftext|>
self.postMessage({
status: "complete",
message: "complete",
output: prompt + sentence,
});
return;
}
const tokensSec =
((tokensCount + 1) / (performance.now() - startTime)) * 1000;
sentence += token;
self.postMessage({
status: "generating",
message: "Generating token",
token: token,
sentence: sentence,
totalTime: performance.now() - startTime,
tokensSec,
prompt: prompt,
});
setTimeout(resolve, 0);
});
tokensCount++;
}
self.postMessage({
status: "complete",
message: "complete",
output: prompt + sentence,
});
} catch (e) {
self.postMessage({ error: e });
}
}

View File

@ -0,0 +1,279 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
generation::LogitsProcessor,
models::{moondream, quantized_moondream},
};
use candle_wasm_example_moondream::console_log;
use js_sys::Date;
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
enum SelectedModel {
Moondream(moondream::Model),
Quantized(quantized_moondream::Model),
}
#[wasm_bindgen]
pub struct Model {
model: SelectedModel,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
tokens: Vec<u32>,
repeat_penalty: f32,
repeat_last_n: usize,
index: usize,
bos_token: Option<Tensor>,
image_embeddings: Option<Tensor>,
}
#[derive(Serialize, Deserialize)]
struct Output {
token: String,
token_id: u32,
}
#[derive(Serialize, Deserialize)]
struct InitInput {
prompt: String,
seed: u64,
temp: f64,
top_p: f64,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, quantized: bool) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let device = Device::Cpu;
let config = moondream::Config::v2();
console_log!("config loaded in {:?}", Date::now());
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let start = Date::now();
console_log!("weights len: {:?}", weights.len());
let model = if quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
&weights, &device,
)?;
console_log!("weights loaded");
let model = quantized_moondream::Model::new(&config, vb)?;
SelectedModel::Quantized(model)
} else {
let device = &Device::Cpu;
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
let model = moondream::Model::new(&config, vb)?;
SelectedModel::Moondream(model)
};
console_log!("model loaded in {:?}s", (Date::now() - start) / 1000.);
let logits_processor = LogitsProcessor::new(299792458, None, None);
Ok(Self {
model,
tokenizer,
tokens: vec![],
logits_processor,
repeat_penalty: 1.,
repeat_last_n: 64,
bos_token: None,
image_embeddings: None,
index: 0,
})
}
pub fn set_image_embeddings(&mut self, image: Vec<u8>) -> Result<(), JsError> {
let device = Device::Cpu;
console_log!("loading image as tensor");
let start = Date::now();
let image: Tensor = self.load_image(image)?.to_device(&device)?;
console_log!("image loaded in {:?}s", (Date::now() - start) / 1000.);
let start = Date::now();
let image_embeds = &image.unsqueeze(0)?;
let image_embeds = match &self.model {
SelectedModel::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
SelectedModel::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
};
console_log!(
"loaded and encoded the image {image:?} in {:?}",
(Date::now() - start) / 1000.
);
self.image_embeddings = Some(image_embeds);
Ok(())
}
#[wasm_bindgen]
pub fn init_with_image_prompt(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let InitInput {
prompt,
seed,
temp,
top_p,
repeat_penalty,
repeat_last_n,
verbose_prompt,
} = serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
let device = Device::Cpu;
let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", prompt);
match &mut self.model {
SelectedModel::Moondream(m) => m.text_model.clear_kv_cache(),
SelectedModel::Quantized(m) => m.text_model.clear_kv_cache(),
};
let temp = if temp <= 0. { None } else { Some(temp) };
let top_p = if top_p <= 0. || top_p >= 1. {
None
} else {
Some(top_p)
};
self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
self.repeat_penalty = repeat_penalty;
self.repeat_last_n = repeat_last_n;
self.tokens.clear();
self.index = 0;
// Moondream tokenizer bos_token is "<|endoftext|>"
// https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => return Err(JsError::new("BOS token not found in the tokenizer.")),
};
self.bos_token = Some(Tensor::new(&[special_token], &device)?.unsqueeze(0)?);
let tokens = self
.tokenizer
.encode(prompt, true)
.map_err(|m| JsError::new(&m.to_string()))?;
if tokens.is_empty() {
return Err(JsError::new(
"Empty prompts are not supported in the Moondream model.",
));
}
if verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let tokens = tokens.get_ids().to_vec();
let text = match self.process(&tokens) {
Ok(text) => text,
Err(_e) => {
console_log!("error decoding token");
Output {
token: "".to_string(),
token_id: 0,
}
}
};
Ok(serde_wasm_bindgen::to_value(&text)?)
}
#[wasm_bindgen]
pub fn next_token(&mut self) -> Result<JsValue, JsError> {
let last_token = *self.tokens.last().unwrap();
let text = match self.process(&[last_token]) {
Ok(text) => text,
Err(_e) => {
console_log!("error decoding token");
Output {
token: "".to_string(),
token_id: 0,
}
}
};
Ok(serde_wasm_bindgen::to_value(&text)?)
}
}
impl Model {
fn load_image(&self, image: Vec<u8>) -> Result<Tensor, JsError> {
let img = image::io::Reader::new(std::io::Cursor::new(image))
.with_guessed_format()?
.decode()
.map_err(|e| JsError::new(&e.to_string()))?
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
.map_err(|e| JsError::new(&e.to_string()))
}
}
impl Model {
fn process(&mut self, tokens: &[u32]) -> Result<Output, JsError> {
let image_embeddings = match &self.image_embeddings {
Some(embeddings) => embeddings,
None => return Err(JsError::new("Image embeddings are not set.")),
};
let bos_token = match &self.bos_token {
Some(token) => token,
None => return Err(JsError::new("BOS token is not set.")),
};
let device = Device::Cpu;
let context_size = if self.index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = if self.index > 0 {
match self.model {
SelectedModel::Moondream(ref mut model) => model.text_model.forward(&input)?,
SelectedModel::Quantized(ref mut model) => model.text_model.forward(&input)?,
}
} else {
match self.model {
SelectedModel::Moondream(ref mut model) => {
model
.text_model
.forward_with_img(bos_token, &input, image_embeddings)?
}
SelectedModel::Quantized(ref mut model) => {
model
.text_model
.forward_with_img(bos_token, &input, image_embeddings)?
}
}
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
self.tokens.push(next_token);
let token = match self.tokenizer.decode(&[next_token], true) {
Ok(token) => token,
Err(e) => {
console_log!("error decoding token: {:?}", e);
"".to_string()
}
};
self.index += 1;
Ok(Output {
token,
token_id: next_token,
})
}
}
fn main() {
console_error_panic_hook::set_once();
}

View File

@ -0,0 +1,16 @@
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}