mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Llama2c WASM UI improvements (#732)
* pass seed, expose model seq_len * wip new llama2.c ui * final new UI example * small coppy * copy
This commit is contained in:
47
candle-wasm-examples/llama2-c/README.md
Normal file
47
candle-wasm-examples/llama2-c/README.md
Normal file
@ -0,0 +1,47 @@
|
||||
## Running [llama2.c](https://github.com/karpathy/llama2.c) Examples
|
||||
|
||||
Here, we provide two examples of how to run [llama2.c](https://github.com/karpathy/llama2.c) written in Rust using a Candle-compiled WASM binary and runtimes.
|
||||
|
||||
### Pure Rust UI
|
||||
|
||||
To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install)
|
||||
From the `candle-wasm-examples/llama2-c` directory run:
|
||||
|
||||
Download assets:
|
||||
|
||||
```bash
|
||||
# Model and tokenizer
|
||||
|
||||
wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
|
||||
wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
|
||||
|
||||
```
|
||||
|
||||
Run hot reload server:
|
||||
|
||||
```bash
|
||||
trunk serve --release --public-url / --port 8080
|
||||
```
|
||||
|
||||
### Vanilla JS and WebWorkers
|
||||
|
||||
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
|
||||
|
||||
```bash
|
||||
sh build-lib.sh
|
||||
```
|
||||
|
||||
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
|
||||
|
||||
```js
|
||||
import init, { Model } from "./build/m.js";
|
||||
```
|
||||
|
||||
The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
|
||||
Finally, you can preview the example by running a local HTTP server. For example:
|
||||
|
||||
```bash
|
||||
python -m http.server
|
||||
```
|
||||
|
||||
Then open `http://localhost:8000/lib-example.html` in your browser.
|
2
candle-wasm-examples/llama2-c/build-lib.sh
Normal file
2
candle-wasm-examples/llama2-c/build-lib.sh
Normal file
@ -0,0 +1,2 @@
|
||||
cargo build --target wasm32-unknown-unknown --release
|
||||
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
|
311
candle-wasm-examples/llama2-c/lib-example.html
Normal file
311
candle-wasm-examples/llama2-c/lib-example.html
Normal file
@ -0,0 +1,311 @@
|
||||
<html>
|
||||
<head>
|
||||
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
|
||||
<title>Candle Llama.c Rust/WASM</title>
|
||||
</head>
|
||||
<body></body>
|
||||
</html>
|
||||
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<style>
|
||||
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
|
||||
html,
|
||||
body {
|
||||
font-family: "Source Sans 3", sans-serif;
|
||||
}
|
||||
code,
|
||||
output,
|
||||
select,
|
||||
pre {
|
||||
font-family: "Source Code Pro", monospace;
|
||||
}
|
||||
</style>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script type="module">
|
||||
// base url for audio examples
|
||||
const MODELS_BASE_URL =
|
||||
"https://huggingface.co/karpathy/tinyllamas/resolve/main";
|
||||
|
||||
// models base url
|
||||
const MODELS = {
|
||||
stories15M: {
|
||||
url: "stories15M.bin",
|
||||
seq_len: 256,
|
||||
},
|
||||
stories42M: {
|
||||
url: "stories42M.bin",
|
||||
seq_len: 256,
|
||||
},
|
||||
stories110M: {
|
||||
url: "stories110M.bin",
|
||||
seq_len: 256,
|
||||
},
|
||||
};
|
||||
|
||||
const llamaWorker = new Worker("./llama2cWorker.js", {
|
||||
type: "module",
|
||||
});
|
||||
async function generateSequence(controller) {
|
||||
const getValue = (id) => document.querySelector(`#${id}`).value;
|
||||
const modelID = getValue("model");
|
||||
const model = MODELS[modelID];
|
||||
const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
|
||||
const prompt = getValue("prompt");
|
||||
const temperature = getValue("temperature");
|
||||
const repeatPenalty = getValue("repeat_penalty");
|
||||
const seed = getValue("seed");
|
||||
const maxSeqLen = getValue("max-seq");
|
||||
|
||||
function updateStatus({ status, message, prompt, sentence }) {
|
||||
const outStatus = document.querySelector("#output-status");
|
||||
const outGen = document.querySelector("#output-generation");
|
||||
|
||||
switch (status) {
|
||||
case "loading":
|
||||
outStatus.hidden = false;
|
||||
outStatus.textContent = message;
|
||||
outGen.hidden = true;
|
||||
break;
|
||||
case "generating":
|
||||
outStatus.hidden = true;
|
||||
outGen.hidden = false;
|
||||
outGen.innerHTML = `<span class="font-semibold">${prompt}</span>${sentence.replace(
|
||||
/\<s\>|\<\/s\>/g,
|
||||
""
|
||||
)}`;
|
||||
break;
|
||||
case "complete":
|
||||
outStatus.hidden = true;
|
||||
outGen.hidden = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
llamaWorker.postMessage({
|
||||
weightsURL,
|
||||
modelID,
|
||||
tokenizerURL: "tokenizer.json",
|
||||
prompt,
|
||||
temp: temperature,
|
||||
repeatPenalty,
|
||||
seed: BigInt(seed),
|
||||
maxSeqLen,
|
||||
command: "start",
|
||||
});
|
||||
|
||||
const handleAbort = () => {
|
||||
llamaWorker.postMessage({ command: "abort" });
|
||||
};
|
||||
const handleMessage = (event) => {
|
||||
const { status, error, message, prompt, sentence } = event.data;
|
||||
if (status) updateStatus(event.data);
|
||||
if (error) reject(new Error(error));
|
||||
if (status === "complete") resolve(event.data);
|
||||
};
|
||||
|
||||
controller.signal.addEventListener("abort", handleAbort);
|
||||
llamaWorker.addEventListener("message", handleMessage);
|
||||
});
|
||||
}
|
||||
|
||||
const form = document.querySelector("#form");
|
||||
const prompt = document.querySelector("#prompt");
|
||||
const clearBtn = document.querySelector("#clear-btn");
|
||||
const runBtn = document.querySelector("#run");
|
||||
let runController = new AbortController();
|
||||
let isRunning = false;
|
||||
|
||||
form.addEventListener("submit", async (e) => {
|
||||
e.preventDefault();
|
||||
if (isRunning) {
|
||||
stopRunning();
|
||||
} else {
|
||||
startRunning();
|
||||
await generateSequence(runController);
|
||||
stopRunning();
|
||||
}
|
||||
});
|
||||
|
||||
function startRunning() {
|
||||
isRunning = true;
|
||||
runBtn.textContent = "Stop";
|
||||
}
|
||||
|
||||
function stopRunning() {
|
||||
runController.abort();
|
||||
runController = new AbortController();
|
||||
runBtn.textContent = "Run";
|
||||
isRunning = false;
|
||||
}
|
||||
clearBtn.addEventListener("click", (e) => {
|
||||
e.preventDefault();
|
||||
prompt.value = "";
|
||||
clearBtn.classList.add("invisible");
|
||||
runBtn.disabled = true;
|
||||
stopRunning();
|
||||
});
|
||||
prompt.addEventListener("input", (e) => {
|
||||
runBtn.disabled = false;
|
||||
if (e.target.value.length > 0) {
|
||||
clearBtn.classList.remove("invisible");
|
||||
} else {
|
||||
clearBtn.classList.add("invisible");
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</head>
|
||||
<body class="container max-w-4xl mx-auto p-4 text-gray-800">
|
||||
<main class="grid grid-cols-1 gap-8 relative">
|
||||
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
|
||||
<div>
|
||||
<h1 class="text-5xl font-bold">Candle Llama2.c</h1>
|
||||
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
|
||||
<p class="max-w-lg">
|
||||
<a
|
||||
href="https://github.com/karpathy/llama2.c"
|
||||
target="_blank"
|
||||
class="underline hover:text-blue-500 hover:no-underline"
|
||||
target="_blank"
|
||||
>Llama2.c</a
|
||||
>
|
||||
is Andrey Karpathy's C implementation of the Llama 2 LLM model in C.
|
||||
This demo uses
|
||||
<a
|
||||
href="https://github.com/huggingface/candle/"
|
||||
target="_blank"
|
||||
class="underline hover:text-blue-500 hover:no-underline"
|
||||
>Candle
|
||||
</a>
|
||||
to run Llama2.c in the browser using rust/wasm.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="model" class="font-medium">Models Options: </label>
|
||||
<select
|
||||
id="model"
|
||||
class="border-2 border-gray-500 rounded-md font-light"
|
||||
>
|
||||
<option value="stories15M" selected>stories 15M (60.8 MB)</option>
|
||||
<option value="stories42M">stories 42M (167 MB)</option>
|
||||
<option value="stories110M">stories 110M (438 MB)</option>
|
||||
</select>
|
||||
</div>
|
||||
<form
|
||||
id="form"
|
||||
class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
|
||||
>
|
||||
<input type="submit" hidden />
|
||||
<input
|
||||
type="text"
|
||||
id="prompt"
|
||||
class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
|
||||
placeholder="Add your prompt here..."
|
||||
/>
|
||||
<button class="invisible" id="clear-btn">
|
||||
<svg
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="40"
|
||||
viewBox="0 0 70 40"
|
||||
>
|
||||
<path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" />
|
||||
<path
|
||||
d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1"
|
||||
opacity=".5"
|
||||
stroke="#1F2937"
|
||||
stroke-width="2"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
id="run"
|
||||
disabled
|
||||
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
|
||||
>
|
||||
Run
|
||||
</button>
|
||||
</form>
|
||||
<div class="grid grid-cols-3 max-w-md items-center gap-3">
|
||||
<label class="text-sm font-medium" for="max-seq">Maximum length </label>
|
||||
<input
|
||||
type="range"
|
||||
id="max-seq"
|
||||
name="temperature"
|
||||
min="1"
|
||||
max="256"
|
||||
step="1"
|
||||
value="200"
|
||||
oninput="this.nextElementSibling.value = Number(this.value)"
|
||||
/>
|
||||
<output
|
||||
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
|
||||
>
|
||||
200</output
|
||||
>
|
||||
<label class="text-sm font-medium" for="temperature">Temperature</label>
|
||||
<input
|
||||
type="range"
|
||||
id="temperature"
|
||||
name="temperature"
|
||||
min="0"
|
||||
max="2"
|
||||
step="0.01"
|
||||
value="0.50"
|
||||
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
|
||||
/>
|
||||
<output
|
||||
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
|
||||
>
|
||||
0.50</output
|
||||
>
|
||||
|
||||
<label class="text-sm font-medium" for="repeat_penalty"
|
||||
>Repeat Penalty</label
|
||||
>
|
||||
|
||||
<input
|
||||
type="range"
|
||||
id="repeat_penalty"
|
||||
name="repeat_penalty"
|
||||
min="-2"
|
||||
max="2"
|
||||
step="0.01"
|
||||
value="1.10"
|
||||
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
|
||||
/>
|
||||
<output
|
||||
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
|
||||
>1.10</output
|
||||
>
|
||||
<label class="text-sm font-medium" for="seed">Seed</label>
|
||||
<input
|
||||
type="number"
|
||||
id="seed"
|
||||
name="seed"
|
||||
value="299792458"
|
||||
class="font-light border border-gray-700 text-right rounded-md p-2"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<h3 class="font-medium">Generation:</h3>
|
||||
|
||||
<div
|
||||
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md grid"
|
||||
>
|
||||
<p hidden id="output-generation"></p>
|
||||
<span
|
||||
id="output-status"
|
||||
class="justify-self-center self-center font-light"
|
||||
>No output yet</span
|
||||
>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
96
candle-wasm-examples/llama2-c/llama2cWorker.js
Normal file
96
candle-wasm-examples/llama2-c/llama2cWorker.js
Normal file
@ -0,0 +1,96 @@
|
||||
import init, { Model } from "./build/m.js";
|
||||
|
||||
async function fetchArrayBuffer(url) {
|
||||
const res = await fetch(url, {
|
||||
cache: "force-cache",
|
||||
});
|
||||
const data = await res.arrayBuffer();
|
||||
return new Uint8Array(data);
|
||||
}
|
||||
|
||||
class Llama2C {
|
||||
static instance = {};
|
||||
|
||||
static async getInstance(weightsURL, modelID, tokenizerURL) {
|
||||
// load individual modelID only once
|
||||
if (!this.instance[modelID]) {
|
||||
await init();
|
||||
|
||||
self.postMessage({ status: "loading", message: "Loading Model" });
|
||||
|
||||
const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
|
||||
fetchArrayBuffer(weightsURL),
|
||||
fetchArrayBuffer(tokenizerURL),
|
||||
]);
|
||||
|
||||
this.instance[modelID] = new Model(weightsArrayU8, tokenizerArrayU8);
|
||||
}
|
||||
return this.instance[modelID];
|
||||
}
|
||||
}
|
||||
|
||||
let controller = null;
|
||||
self.addEventListener("message", (event) => {
|
||||
if (event.data.command === "start") {
|
||||
controller = new AbortController();
|
||||
generate(event.data);
|
||||
} else if (event.data.command === "abort") {
|
||||
controller.abort();
|
||||
}
|
||||
});
|
||||
|
||||
async function generate(data) {
|
||||
const {
|
||||
weightsURL,
|
||||
modelID,
|
||||
tokenizerURL,
|
||||
prompt,
|
||||
temp,
|
||||
repeatPenalty,
|
||||
seed,
|
||||
maxSeqLen,
|
||||
} = data;
|
||||
try {
|
||||
self.postMessage({ status: "loading", message: "Starting llama2.c" });
|
||||
const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
|
||||
|
||||
self.postMessage({ status: "loading", message: "Initializing model" });
|
||||
model.init_with_prompt(prompt, temp, repeatPenalty, seed);
|
||||
|
||||
const seq_len = model.get_seq_len();
|
||||
|
||||
let sentence = "";
|
||||
let max_tokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
|
||||
|
||||
while (max_tokens--) {
|
||||
await new Promise(async (resolve) => {
|
||||
if (controller && controller.signal.aborted) {
|
||||
self.postMessage({
|
||||
status: "aborted",
|
||||
message: "Aborted",
|
||||
output: prompt + sentence,
|
||||
});
|
||||
return;
|
||||
}
|
||||
const token = await model.next_token();
|
||||
|
||||
sentence += token;
|
||||
self.postMessage({
|
||||
status: "generating",
|
||||
message: "Generating token",
|
||||
token: token,
|
||||
sentence: sentence,
|
||||
prompt: prompt,
|
||||
});
|
||||
setTimeout(resolve, 0);
|
||||
});
|
||||
}
|
||||
self.postMessage({
|
||||
status: "complete",
|
||||
message: "complete",
|
||||
output: prompt + sentence,
|
||||
});
|
||||
} catch (e) {
|
||||
self.postMessage({ error: e });
|
||||
}
|
||||
}
|
@ -58,6 +58,11 @@ impl Model {
|
||||
Err(e) => Err(JsError::new(&e.to_string())),
|
||||
}
|
||||
}
|
||||
#[wasm_bindgen]
|
||||
pub fn get_seq_len(&mut self) -> usize {
|
||||
let seq_len = self.inner.config.seq_len;
|
||||
seq_len
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn init_with_prompt(
|
||||
@ -65,6 +70,7 @@ impl Model {
|
||||
prompt: String,
|
||||
temp: f64,
|
||||
repeat_penalty: f32,
|
||||
seed: u64,
|
||||
) -> Result<String, JsError> {
|
||||
// First reset the cache.
|
||||
{
|
||||
@ -74,7 +80,7 @@ impl Model {
|
||||
}
|
||||
}
|
||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||
self.logits_processor = LogitsProcessor::new(299792458, temp);
|
||||
self.logits_processor = LogitsProcessor::new(seed, temp);
|
||||
self.repeat_penalty = repeat_penalty;
|
||||
self.tokens.clear();
|
||||
let tokens = self
|
||||
|
@ -51,7 +51,7 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
|
||||
|
||||
pub struct Model {
|
||||
pub cache: Cache,
|
||||
config: Config,
|
||||
pub config: Config,
|
||||
pub llama: Llama,
|
||||
pub tokenizer: Tokenizer,
|
||||
}
|
||||
|
Reference in New Issue
Block a user