Llama2c WASM UI improvements (#732)

* pass seed, expose model seq_len

* wip new llama2.c ui

* final new UI example

* small coppy

* copy
This commit is contained in:
Radamés Ajna
2023-09-04 07:59:22 -07:00
committed by GitHub
parent e2f9f60ac2
commit 8395152d20
6 changed files with 464 additions and 2 deletions

View File

@ -0,0 +1,47 @@
## Running [llama2.c](https://github.com/karpathy/llama2.c) Examples
Here, we provide two examples of how to run [llama2.c](https://github.com/karpathy/llama2.c) written in Rust using a Candle-compiled WASM binary and runtimes.
### Pure Rust UI
To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install)
From the `candle-wasm-examples/llama2-c` directory run:
Download assets:
```bash
# Model and tokenizer
wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
```
Run hot reload server:
```bash
trunk serve --release --public-url / --port 8080
```
### Vanilla JS and WebWorkers
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
```bash
sh build-lib.sh
```
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
```js
import init, { Model } from "./build/m.js";
```
The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:
```bash
python -m http.server
```
Then open `http://localhost:8000/lib-example.html` in your browser.

View File

@ -0,0 +1,2 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web

View File

@ -0,0 +1,311 @@
<html>
<head>
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
<title>Candle Llama.c Rust/WASM</title>
</head>
<body></body>
</html>
<!doctype html>
<html>
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<style>
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
html,
body {
font-family: "Source Sans 3", sans-serif;
}
code,
output,
select,
pre {
font-family: "Source Code Pro", monospace;
}
</style>
<script src="https://cdn.tailwindcss.com"></script>
<script type="module">
// base url for audio examples
const MODELS_BASE_URL =
"https://huggingface.co/karpathy/tinyllamas/resolve/main";
// models base url
const MODELS = {
stories15M: {
url: "stories15M.bin",
seq_len: 256,
},
stories42M: {
url: "stories42M.bin",
seq_len: 256,
},
stories110M: {
url: "stories110M.bin",
seq_len: 256,
},
};
const llamaWorker = new Worker("./llama2cWorker.js", {
type: "module",
});
async function generateSequence(controller) {
const getValue = (id) => document.querySelector(`#${id}`).value;
const modelID = getValue("model");
const model = MODELS[modelID];
const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
const prompt = getValue("prompt");
const temperature = getValue("temperature");
const repeatPenalty = getValue("repeat_penalty");
const seed = getValue("seed");
const maxSeqLen = getValue("max-seq");
function updateStatus({ status, message, prompt, sentence }) {
const outStatus = document.querySelector("#output-status");
const outGen = document.querySelector("#output-generation");
switch (status) {
case "loading":
outStatus.hidden = false;
outStatus.textContent = message;
outGen.hidden = true;
break;
case "generating":
outStatus.hidden = true;
outGen.hidden = false;
outGen.innerHTML = `<span class="font-semibold">${prompt}</span>${sentence.replace(
/\<s\>|\<\/s\>/g,
""
)}`;
break;
case "complete":
outStatus.hidden = true;
outGen.hidden = false;
break;
}
}
return new Promise((resolve, reject) => {
llamaWorker.postMessage({
weightsURL,
modelID,
tokenizerURL: "tokenizer.json",
prompt,
temp: temperature,
repeatPenalty,
seed: BigInt(seed),
maxSeqLen,
command: "start",
});
const handleAbort = () => {
llamaWorker.postMessage({ command: "abort" });
};
const handleMessage = (event) => {
const { status, error, message, prompt, sentence } = event.data;
if (status) updateStatus(event.data);
if (error) reject(new Error(error));
if (status === "complete") resolve(event.data);
};
controller.signal.addEventListener("abort", handleAbort);
llamaWorker.addEventListener("message", handleMessage);
});
}
const form = document.querySelector("#form");
const prompt = document.querySelector("#prompt");
const clearBtn = document.querySelector("#clear-btn");
const runBtn = document.querySelector("#run");
let runController = new AbortController();
let isRunning = false;
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (isRunning) {
stopRunning();
} else {
startRunning();
await generateSequence(runController);
stopRunning();
}
});
function startRunning() {
isRunning = true;
runBtn.textContent = "Stop";
}
function stopRunning() {
runController.abort();
runController = new AbortController();
runBtn.textContent = "Run";
isRunning = false;
}
clearBtn.addEventListener("click", (e) => {
e.preventDefault();
prompt.value = "";
clearBtn.classList.add("invisible");
runBtn.disabled = true;
stopRunning();
});
prompt.addEventListener("input", (e) => {
runBtn.disabled = false;
if (e.target.value.length > 0) {
clearBtn.classList.remove("invisible");
} else {
clearBtn.classList.add("invisible");
}
});
</script>
</head>
<body class="container max-w-4xl mx-auto p-4 text-gray-800">
<main class="grid grid-cols-1 gap-8 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
<h1 class="text-5xl font-bold">Candle Llama2.c</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
<a
href="https://github.com/karpathy/llama2.c"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
target="_blank"
>Llama2.c</a
>
is Andrey Karpathy's C implementation of the Llama 2 LLM model in C.
This demo uses
<a
href="https://github.com/huggingface/candle/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>Candle
</a>
to run Llama2.c in the browser using rust/wasm.
</p>
</div>
<div>
<label for="model" class="font-medium">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light"
>
<option value="stories15M" selected>stories 15M (60.8 MB)</option>
<option value="stories42M">stories 42M (167 MB)</option>
<option value="stories110M">stories 110M (438 MB)</option>
</select>
</div>
<form
id="form"
class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
>
<input type="submit" hidden />
<input
type="text"
id="prompt"
class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
placeholder="Add your prompt here..."
/>
<button class="invisible" id="clear-btn">
<svg
fill="none"
xmlns="http://www.w3.org/2000/svg"
width="40"
viewBox="0 0 70 40"
>
<path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" />
<path
d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1"
opacity=".5"
stroke="#1F2937"
stroke-width="2"
/>
</svg>
</button>
<button
id="run"
disabled
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
>
Run
</button>
</form>
<div class="grid grid-cols-3 max-w-md items-center gap-3">
<label class="text-sm font-medium" for="max-seq">Maximum length </label>
<input
type="range"
id="max-seq"
name="temperature"
min="1"
max="256"
step="1"
value="200"
oninput="this.nextElementSibling.value = Number(this.value)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>
200</output
>
<label class="text-sm font-medium" for="temperature">Temperature</label>
<input
type="range"
id="temperature"
name="temperature"
min="0"
max="2"
step="0.01"
value="0.50"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>
0.50</output
>
<label class="text-sm font-medium" for="repeat_penalty"
>Repeat Penalty</label
>
<input
type="range"
id="repeat_penalty"
name="repeat_penalty"
min="-2"
max="2"
step="0.01"
value="1.10"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>1.10</output
>
<label class="text-sm font-medium" for="seed">Seed</label>
<input
type="number"
id="seed"
name="seed"
value="299792458"
class="font-light border border-gray-700 text-right rounded-md p-2"
/>
</div>
<div>
<h3 class="font-medium">Generation:</h3>
<div
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md grid"
>
<p hidden id="output-generation"></p>
<span
id="output-status"
class="justify-self-center self-center font-light"
>No output yet</span
>
</div>
</div>
</main>
</body>
</html>

View File

@ -0,0 +1,96 @@
import init, { Model } from "./build/m.js";
async function fetchArrayBuffer(url) {
const res = await fetch(url, {
cache: "force-cache",
});
const data = await res.arrayBuffer();
return new Uint8Array(data);
}
class Llama2C {
static instance = {};
static async getInstance(weightsURL, modelID, tokenizerURL) {
// load individual modelID only once
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
]);
this.instance[modelID] = new Model(weightsArrayU8, tokenizerArrayU8);
}
return this.instance[modelID];
}
}
let controller = null;
self.addEventListener("message", (event) => {
if (event.data.command === "start") {
controller = new AbortController();
generate(event.data);
} else if (event.data.command === "abort") {
controller.abort();
}
});
async function generate(data) {
const {
weightsURL,
modelID,
tokenizerURL,
prompt,
temp,
repeatPenalty,
seed,
maxSeqLen,
} = data;
try {
self.postMessage({ status: "loading", message: "Starting llama2.c" });
const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
self.postMessage({ status: "loading", message: "Initializing model" });
model.init_with_prompt(prompt, temp, repeatPenalty, seed);
const seq_len = model.get_seq_len();
let sentence = "";
let max_tokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
while (max_tokens--) {
await new Promise(async (resolve) => {
if (controller && controller.signal.aborted) {
self.postMessage({
status: "aborted",
message: "Aborted",
output: prompt + sentence,
});
return;
}
const token = await model.next_token();
sentence += token;
self.postMessage({
status: "generating",
message: "Generating token",
token: token,
sentence: sentence,
prompt: prompt,
});
setTimeout(resolve, 0);
});
}
self.postMessage({
status: "complete",
message: "complete",
output: prompt + sentence,
});
} catch (e) {
self.postMessage({ error: e });
}
}

View File

@ -58,6 +58,11 @@ impl Model {
Err(e) => Err(JsError::new(&e.to_string())),
}
}
#[wasm_bindgen]
pub fn get_seq_len(&mut self) -> usize {
let seq_len = self.inner.config.seq_len;
seq_len
}
#[wasm_bindgen]
pub fn init_with_prompt(
@ -65,6 +70,7 @@ impl Model {
prompt: String,
temp: f64,
repeat_penalty: f32,
seed: u64,
) -> Result<String, JsError> {
// First reset the cache.
{
@ -74,7 +80,7 @@ impl Model {
}
}
let temp = if temp <= 0. { None } else { Some(temp) };
self.logits_processor = LogitsProcessor::new(299792458, temp);
self.logits_processor = LogitsProcessor::new(seed, temp);
self.repeat_penalty = repeat_penalty;
self.tokens.clear();
let tokens = self

View File

@ -51,7 +51,7 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
pub struct Model {
pub cache: Cache,
config: Config,
pub config: Config,
pub llama: Llama,
pub tokenizer: Tokenizer,
}