[Wasm] Add puffin phi model to wasm (#1166)

* load config from file, add puffin phi links

* format

* add prompt examples
This commit is contained in:
Radamés Ajna
2023-10-24 23:09:03 -07:00
committed by GitHub
parent 45dbe541bc
commit b6053b938b
4 changed files with 206 additions and 39 deletions

View File

@ -4,11 +4,12 @@ use crate::models::with_tracing::{linear, Embedding as E, Linear};
/// https://arxiv.org/abs/2309.05463
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
const MAX_SEQ_LEN: usize = 4096;
// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) n_positions: usize,

View File

@ -13,7 +13,8 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/styles/default.min.css" />
href="https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/styles/default.min.css"
/>
<style>
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
html,
@ -36,27 +37,110 @@
<script type="module">
import snarkdown from "https://cdn.skypack.dev/snarkdown";
import hljs from "https://cdn.skypack.dev/highlight.js";
const TOKENIZER_URL =
"https://huggingface.co/microsoft/phi-1_5/raw/main/tokenizer.json";
// models base url
const MODELS = {
phi_1_5_quantized: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-q4k.gguf",
tokenizer: "tokenizer.json",
config: "phi-1_5.json",
quantized: true,
seq_len: 2048,
size: "800 MB",
},
phi_1_5_quantized_2: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-q80.gguf",
tokenizer: "tokenizer.json",
config: "phi-1_5.json",
quantized: true,
seq_len: 2048,
size: "1.51 GB",
},
puffin_phi_v2_quantized: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-puffin-phi-v2-q4k.gguf",
tokenizer: "tokenizer-puffin-phi-v2.json",
config: "puffin-phi-v2.json",
quantized: true,
seq_len: 2048,
size: "798 MB",
},
puffin_phi_v2_quantized_2: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-puffin-phi-v2-q80.gguf",
tokenizer: "tokenizer-puffin-phi-v2.json",
config: "puffin-phi-v2.json",
quantized: true,
seq_len: 2048,
size: "1.50 GB",
},
};
const TEMPLATES = [
{
title: "Simple prompt",
prompt: `Sebastien is in London today, its the middle of July yet its raining, so Sebastien is feeling gloomy. He`,
},
{
title: "Think step by step",
prompt: `Suppose Alice originally had 3 apples, then Bob gave Alice 7 apples, then Alice gave Cook 5 apples, and then Tim gave Alice 3x the amount of apples Alice had. How many apples does Alice have now?
Lets think step by step.`,
},
{
title: "Explaing a code snippet",
prompt: `What does this script do?
\`\`\`python
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('', 0))
s.listen(1)
conn, addr = s.accept()
print('Connected by', addr)
return conn.getsockname()[1]
\`\`\`
Lets think step by step.`,
},
{
title: "Question answering",
prompt: `What is the capital of France?
Answer:`,
},
{
title: "Chat mode",
prompt: `Alice: Can you tell me how to create a python application to go through all the files
in one directory where the files name DOES NOT end with '.json'?
Bob:`,
},
{
title: "Python code completion",
prompt: `"""write a python function called batch(function, list) which call function(x) for x in
list in parallel"""
Solution:`,
},
{
title: "Python Sample",
prompt: `"""Can you make sure those histograms appear side by side on the same plot:
\`\`\`python
plt.hist(intreps_retrained[0][1].view(64,-1).norm(dim=1).detach().cpu().numpy(), bins = 20)
plt.hist(intreps_pretrained[0][1].view(64,-1).norm(dim=1).detach().cpu().numpy(), bins = 20)
\`\`\`
"""`,
},
{
title: "Write a Twitter post",
prompt: `Write a twitter post for the discovery of gravitational wave.
Twitter Post:`,
},
{
title: "Write a review",
prompt: `Write a polite review complaining that the video game 'Random Game' was too badly optimized and it burned my laptop.
Very polite review:`,
},
];
const phiWorker = new Worker("./phiWorker.js", {
type: "module",
});
@ -65,6 +149,8 @@
const modelID = getValue("model");
const model = MODELS[modelID];
const weightsURL = model.base_url + model.model;
const tokenizerURL = model.base_url + model.tokenizer;
const configURL = model.base_url + model.config;
const prompt = getValue("prompt").trim();
const temperature = getValue("temperature");
@ -107,7 +193,8 @@
phiWorker.postMessage({
weightsURL,
modelID,
tokenizerURL: TOKENIZER_URL,
tokenizerURL,
configURL,
quantized: model.quantized,
prompt,
temp: temperature,
@ -148,9 +235,42 @@
const clearBtn = document.querySelector("#clear-btn");
const runBtn = document.querySelector("#run");
const modelSelect = document.querySelector("#model");
const promptTemplates = document.querySelector("#prompt-templates");
let runController = new AbortController();
let isRunning = false;
document.addEventListener("DOMContentLoaded", () => {
for (const [id, model] of Object.entries(MODELS)) {
const option = document.createElement("option");
option.value = id;
option.innerText = `${id} (${model.size})`;
modelSelect.appendChild(option);
}
for (const [i, { title, prompt }] of TEMPLATES.entries()) {
const div = document.createElement("div");
const input = document.createElement("input");
input.type = "radio";
input.name = "task";
input.id = `templates-${i}`;
input.classList.add("font-light", "cursor-pointer");
input.value = prompt;
const label = document.createElement("label");
label.htmlFor = `templates-${i}`;
label.classList.add("cursor-pointer");
label.innerText = title;
div.appendChild(input);
div.appendChild(label);
promptTemplates.appendChild(div);
}
});
promptTemplates.addEventListener("change", (e) => {
const template = e.target.value;
prompt.value = template;
prompt.style.height = "auto";
prompt.style.height = prompt.scrollHeight + "px";
});
modelSelect.addEventListener("change", (e) => {
const model = MODELS[e.target.value];
document.querySelector("#max-seq").max = model.seq_len;
@ -217,10 +337,27 @@
<a
href="https://arxiv.org/pdf/2309.05463.pdf#page=8"
class="link"
target="_blank">
target="_blank"
>
technical report </a
>.
</p>
<p class="max-w-lg">
You can also try
<a
href="https://huggingface.co/teknium/Puffin-Phi-v2"
class="link"
target="_blank"
>Puffin-Phi V2
</a>
quantized version model, a fine-tuned version of Phi-1.5 on the
<a
href="https://huggingface.co/datasets/LDJnr/Puffin"
class="link"
target="_blank"
>Puffin dataset
</a>
</p>
</div>
<div>
<p class="text-xs italic max-w-lg">
@ -234,26 +371,25 @@
<label for="model" class="font-medium">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light">
<option value="phi_1_5_quantized" selected>
phi 1.5 quantized q4k (800 MB)
</option>
<option value="phi_1_5_quantized_2">
phi 1.5 quantized q80 (1.51 GB)
</option>
<!-- <option value="phi_1_5">phi 1.5 (2.84 GB)</option> -->
</select>
class="border-2 border-gray-500 rounded-md font-light"
></select>
</div>
<div>
<h3 class="font-medium">Prompt Templates</h3>
<form id="prompt-templates" class="flex flex-col gap-1 my-2"></form>
</div>
<form
id="form"
class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
>
<input type="submit" hidden />
<textarea
type="text"
id="prompt"
class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
oninput="this.style.height = 0;this.style.height = this.scrollHeight + 'px'"
placeholder="Add your prompt here...">
placeholder="Add your prompt here..."
>
Write a detailed analogy between mathematics and a lighthouse.
Answer:</textarea
>
@ -262,18 +398,21 @@ Answer:</textarea
fill="none"
xmlns="http://www.w3.org/2000/svg"
width="40"
viewBox="0 0 70 40">
viewBox="0 0 70 40"
>
<path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" />
<path
d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1"
opacity=".5"
stroke="#1F2937"
stroke-width="2" />
stroke-width="2"
/>
</svg>
</button>
<button
id="run"
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
>
Run
</button>
</form>
@ -292,9 +431,11 @@ Answer:</textarea
max="2048"
step="1"
value="200"
oninput="this.nextElementSibling.value = Number(this.value)" />
oninput="this.nextElementSibling.value = Number(this.value)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>
200</output
>
<label class="text-sm font-medium" for="temperature"
@ -308,9 +449,11 @@ Answer:</textarea
max="2"
step="0.01"
value="0.00"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>
0.00</output
>
<label class="text-sm font-medium" for="top-p">Top-p</label>
@ -322,9 +465,11 @@ Answer:</textarea
max="1"
step="0.01"
value="1.00"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>
1.00</output
>
@ -340,7 +485,8 @@ Answer:</textarea
max="2"
step="0.01"
value="1.10"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
/>
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>1.10</output
@ -351,11 +497,13 @@ Answer:</textarea
id="seed"
name="seed"
value="299792458"
class="font-light border border-gray-700 text-right rounded-md p-2" />
class="font-light border border-gray-700 text-right rounded-md p-2"
/>
<button
id="run"
onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm"
>
Rand
</button>
</div>
@ -364,11 +512,13 @@ Answer:</textarea
<div>
<h3 class="font-medium">Generation:</h3>
<div
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
>
<div
id="output-counter"
hidden
class="ml-auto font-semibold grid-rows-1 text-sm"></div>
class="ml-auto font-semibold grid-rows-1 text-sm"
></div>
<p hidden id="output-generation" class="grid-rows-2"></p>
<span id="output-status" class="m-auto font-light"
>No output yet</span

View File

@ -15,21 +15,30 @@ async function fetchArrayBuffer(url) {
class Phi {
static instance = {};
static async getInstance(weightsURL, modelID, tokenizerURL, quantized) {
static async getInstance(
weightsURL,
modelID,
tokenizerURL,
configURL,
quantized
) {
// load individual modelID only once
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
]);
const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(configURL),
]);
this.instance[modelID] = new Model(
weightsArrayU8,
tokenizerArrayU8,
configArrayU8,
quantized
);
}
@ -52,6 +61,7 @@ async function generate(data) {
weightsURL,
modelID,
tokenizerURL,
configURL,
quantized,
prompt,
temp,
@ -66,6 +76,7 @@ async function generate(data) {
weightsURL,
modelID,
tokenizerURL,
configURL,
quantized
);

View File

@ -26,10 +26,15 @@ pub struct Model {
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, quantized: bool) -> Result<Model, JsError> {
pub fn load(
weights: Vec<u8>,
tokenizer: Vec<u8>,
config: Vec<u8>,
quantized: bool,
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let config: Config = Config::v1_5();
let config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let start = Date::now();