[Wasm] Add puffin phi model to wasm (#1166)
* load config from file, add puffin phi links
* format
* add prompt examples
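The core change is that the wasm `Model::load` no longer hard-codes `Config::v1_5()` but deserializes a per-model JSON config fetched by the page (for example `phi-1_5.json` or `puffin-phi-v2.json`). A minimal sketch of that pattern follows, assuming a stripped-down `Config` with only the two fields visible in the diff below (the real struct has more fields); the JSON values are illustrative, not taken from the actual config files:

use serde::Deserialize;

// Hypothetical, trimmed-down stand-in for the real Config struct.
#[derive(Debug, Clone, PartialEq, Deserialize)]
struct Config {
    vocab_size: usize,
    n_positions: usize,
}

fn main() -> Result<(), serde_json::Error> {
    // Stand-in for the bytes the page fetches from `base_url + model.config`.
    let raw = br#"{ "vocab_size": 51200, "n_positions": 2048 }"#;
    let config: Config = serde_json::from_slice(raw)?;
    println!("loaded config: {config:?}");
    Ok(())
}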
@@ -4,11 +4,12 @@ use crate::models::with_tracing::{linear, Embedding as E, Linear};
 /// https://arxiv.org/abs/2309.05463
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
+use serde::Deserialize;
 
 const MAX_SEQ_LEN: usize = 4096;
 
 // https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     pub(crate) vocab_size: usize,
     pub(crate) n_positions: usize,
@@ -13,7 +13,8 @@
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <link
       rel="stylesheet"
-      href="https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/styles/default.min.css" />
+      href="https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/styles/default.min.css"
+    />
     <style>
       @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
       html,
@@ -36,27 +37,110 @@
     <script type="module">
       import snarkdown from "https://cdn.skypack.dev/snarkdown";
       import hljs from "https://cdn.skypack.dev/highlight.js";
-
-      const TOKENIZER_URL =
-        "https://huggingface.co/microsoft/phi-1_5/raw/main/tokenizer.json";
       // models base url
       const MODELS = {
         phi_1_5_quantized: {
           base_url:
             "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
           model: "model-q4k.gguf",
+          tokenizer: "tokenizer.json",
+          config: "phi-1_5.json",
           quantized: true,
           seq_len: 2048,
+          size: "800 MB",
         },
         phi_1_5_quantized_2: {
           base_url:
             "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
           model: "model-q80.gguf",
+          tokenizer: "tokenizer.json",
+          config: "phi-1_5.json",
           quantized: true,
           seq_len: 2048,
+          size: "1.51 GB",
+        },
+        puffin_phi_v2_quantized: {
+          base_url:
+            "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
+          model: "model-puffin-phi-v2-q4k.gguf",
+          tokenizer: "tokenizer-puffin-phi-v2.json",
+          config: "puffin-phi-v2.json",
+          quantized: true,
+          seq_len: 2048,
+          size: "798 MB",
+        },
+        puffin_phi_v2_quantized_2: {
+          base_url:
+            "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
+          model: "model-puffin-phi-v2-q80.gguf",
+          tokenizer: "tokenizer-puffin-phi-v2.json",
+          config: "puffin-phi-v2.json",
+          quantized: true,
+          seq_len: 2048,
+          size: "1.50 GB",
         },
       };
 
+      const TEMPLATES = [
+        {
+          title: "Simple prompt",
+          prompt: `Sebastien is in London today, it’s the middle of July yet it’s raining, so Sebastien is feeling gloomy. He`,
+        },
+        {
+          title: "Think step by step",
+          prompt: `Suppose Alice originally had 3 apples, then Bob gave Alice 7 apples, then Alice gave Cook 5 apples, and then Tim gave Alice 3x the amount of apples Alice had. How many apples does Alice have now?
+Let’s think step by step.`,
+        },
+        {
+          title: "Explaing a code snippet",
+          prompt: `What does this script do?
+\`\`\`python
+s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+s.bind(('', 0))
+s.listen(1)
+conn, addr = s.accept()
+print('Connected by', addr)
+return conn.getsockname()[1]
+\`\`\`
+Let’s think step by step.`,
+        },
+        {
+          title: "Question answering",
+          prompt: `What is the capital of France?
+Answer:`,
+        },
+        {
+          title: "Chat mode",
+          prompt: `Alice: Can you tell me how to create a python application to go through all the files
+in one directory where the file’s name DOES NOT end with '.json'?
+Bob:`,
+        },
+        {
+          title: "Python code completion",
+          prompt: `"""write a python function called batch(function, list) which call function(x) for x in
+list in parallel"""
+Solution:`,
+        },
+        {
+          title: "Python Sample",
+          prompt: `"""Can you make sure those histograms appear side by side on the same plot:
+\`\`\`python
+plt.hist(intreps_retrained[0][1].view(64,-1).norm(dim=1).detach().cpu().numpy(), bins = 20)
+plt.hist(intreps_pretrained[0][1].view(64,-1).norm(dim=1).detach().cpu().numpy(), bins = 20)
+\`\`\`
+"""`,
+        },
+        {
+          title: "Write a Twitter post",
+          prompt: `Write a twitter post for the discovery of gravitational wave.
+Twitter Post:`,
+        },
+        {
+          title: "Write a review",
+          prompt: `Write a polite review complaining that the video game 'Random Game' was too badly optimized and it burned my laptop.
+Very polite review:`,
+        },
+      ];
       const phiWorker = new Worker("./phiWorker.js", {
         type: "module",
       });
@@ -65,6 +149,8 @@
       const modelID = getValue("model");
       const model = MODELS[modelID];
       const weightsURL = model.base_url + model.model;
+      const tokenizerURL = model.base_url + model.tokenizer;
+      const configURL = model.base_url + model.config;
 
       const prompt = getValue("prompt").trim();
       const temperature = getValue("temperature");
@@ -107,7 +193,8 @@
       phiWorker.postMessage({
         weightsURL,
         modelID,
-        tokenizerURL: TOKENIZER_URL,
+        tokenizerURL,
+        configURL,
         quantized: model.quantized,
         prompt,
         temp: temperature,
@@ -148,9 +235,42 @@
       const clearBtn = document.querySelector("#clear-btn");
       const runBtn = document.querySelector("#run");
       const modelSelect = document.querySelector("#model");
+      const promptTemplates = document.querySelector("#prompt-templates");
       let runController = new AbortController();
       let isRunning = false;
 
+      document.addEventListener("DOMContentLoaded", () => {
+        for (const [id, model] of Object.entries(MODELS)) {
+          const option = document.createElement("option");
+          option.value = id;
+          option.innerText = `${id} (${model.size})`;
+          modelSelect.appendChild(option);
+        }
+
+        for (const [i, { title, prompt }] of TEMPLATES.entries()) {
+          const div = document.createElement("div");
+          const input = document.createElement("input");
+          input.type = "radio";
+          input.name = "task";
+          input.id = `templates-${i}`;
+          input.classList.add("font-light", "cursor-pointer");
+          input.value = prompt;
+          const label = document.createElement("label");
+          label.htmlFor = `templates-${i}`;
+          label.classList.add("cursor-pointer");
+          label.innerText = title;
+          div.appendChild(input);
+          div.appendChild(label);
+          promptTemplates.appendChild(div);
+        }
+      });
+
+      promptTemplates.addEventListener("change", (e) => {
+        const template = e.target.value;
+        prompt.value = template;
+        prompt.style.height = "auto";
+        prompt.style.height = prompt.scrollHeight + "px";
+      });
       modelSelect.addEventListener("change", (e) => {
         const model = MODELS[e.target.value];
         document.querySelector("#max-seq").max = model.seq_len;
@@ -217,10 +337,27 @@
          <a
            href="https://arxiv.org/pdf/2309.05463.pdf#page=8"
            class="link"
-           target="_blank">
+           target="_blank"
+          >
            technical report </a
          >.
        </p>
+        <p class="max-w-lg">
+          You can also try
+          <a
+            href="https://huggingface.co/teknium/Puffin-Phi-v2"
+            class="link"
+            target="_blank"
+            >Puffin-Phi V2
+          </a>
+          quantized version model, a fine-tuned version of Phi-1.5 on the
+          <a
+            href="https://huggingface.co/datasets/LDJnr/Puffin"
+            class="link"
+            target="_blank"
+            >Puffin dataset
+          </a>
+        </p>
      </div>
      <div>
        <p class="text-xs italic max-w-lg">
@@ -234,26 +371,25 @@
          <label for="model" class="font-medium">Models Options: </label>
          <select
            id="model"
-           class="border-2 border-gray-500 rounded-md font-light">
-           <option value="phi_1_5_quantized" selected>
-             phi 1.5 quantized q4k (800 MB)
-           </option>
-           <option value="phi_1_5_quantized_2">
-             phi 1.5 quantized q80 (1.51 GB)
-           </option>
-           <!-- <option value="phi_1_5">phi 1.5 (2.84 GB)</option> -->
-         </select>
+           class="border-2 border-gray-500 rounded-md font-light"
+         ></select>
+       </div>
+       <div>
+         <h3 class="font-medium">Prompt Templates</h3>
+         <form id="prompt-templates" class="flex flex-col gap-1 my-2"></form>
        </div>
        <form
          id="form"
-         class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
+         class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
+       >
          <input type="submit" hidden />
          <textarea
            type="text"
            id="prompt"
            class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
            oninput="this.style.height = 0;this.style.height = this.scrollHeight + 'px'"
-           placeholder="Add your prompt here...">
+           placeholder="Add your prompt here..."
+         >
 Write a detailed analogy between mathematics and a lighthouse.
 Answer:</textarea
          >
@@ -262,18 +398,21 @@ Answer:</textarea
            fill="none"
            xmlns="http://www.w3.org/2000/svg"
            width="40"
-           viewBox="0 0 70 40">
+           viewBox="0 0 70 40"
+         >
            <path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" />
            <path
              d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1"
              opacity=".5"
              stroke="#1F2937"
-             stroke-width="2" />
+             stroke-width="2"
+           />
          </svg>
        </button>
        <button
          id="run"
-         class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
+         class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
+       >
          Run
        </button>
      </form>
@@ -292,9 +431,11 @@ Answer:</textarea
          max="2048"
          step="1"
          value="200"
-         oninput="this.nextElementSibling.value = Number(this.value)" />
+         oninput="this.nextElementSibling.value = Number(this.value)"
+       />
        <output
-         class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+         class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+       >
          200</output
        >
        <label class="text-sm font-medium" for="temperature"
@@ -308,9 +449,11 @@ Answer:</textarea
          max="2"
          step="0.01"
          value="0.00"
-         oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+         oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+       />
        <output
-         class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+         class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+       >
          0.00</output
        >
        <label class="text-sm font-medium" for="top-p">Top-p</label>
@@ -322,9 +465,11 @@ Answer:</textarea
          max="1"
          step="0.01"
          value="1.00"
-         oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+         oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+       />
        <output
-         class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+         class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+       >
          1.00</output
        >
 
@@ -340,7 +485,8 @@ Answer:</textarea
          max="2"
          step="0.01"
          value="1.10"
-         oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+         oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+       />
        <output
          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
          >1.10</output
@@ -351,11 +497,13 @@ Answer:</textarea
          id="seed"
          name="seed"
          value="299792458"
-         class="font-light border border-gray-700 text-right rounded-md p-2" />
+         class="font-light border border-gray-700 text-right rounded-md p-2"
+       />
        <button
          id="run"
          onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
-         class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
+         class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm"
+       >
          Rand
        </button>
      </div>
@@ -364,11 +512,13 @@ Answer:</textarea
      <div>
        <h3 class="font-medium">Generation:</h3>
        <div
-         class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2">
+         class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
+       >
          <div
            id="output-counter"
            hidden
-           class="ml-auto font-semibold grid-rows-1 text-sm"></div>
+           class="ml-auto font-semibold grid-rows-1 text-sm"
+         ></div>
          <p hidden id="output-generation" class="grid-rows-2"></p>
          <span id="output-status" class="m-auto font-light"
            >No output yet</span
@@ -15,21 +15,30 @@ async function fetchArrayBuffer(url) {
 class Phi {
   static instance = {};
 
-  static async getInstance(weightsURL, modelID, tokenizerURL, quantized) {
+  static async getInstance(
+    weightsURL,
+    modelID,
+    tokenizerURL,
+    configURL,
+    quantized
+  ) {
     // load individual modelID only once
     if (!this.instance[modelID]) {
       await init();
 
       self.postMessage({ status: "loading", message: "Loading Model" });
 
-      const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
-        fetchArrayBuffer(weightsURL),
-        fetchArrayBuffer(tokenizerURL),
-      ]);
+      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
+        await Promise.all([
+          fetchArrayBuffer(weightsURL),
+          fetchArrayBuffer(tokenizerURL),
+          fetchArrayBuffer(configURL),
+        ]);
 
       this.instance[modelID] = new Model(
         weightsArrayU8,
         tokenizerArrayU8,
+        configArrayU8,
         quantized
       );
     }
@@ -52,6 +61,7 @@ async function generate(data) {
     weightsURL,
     modelID,
     tokenizerURL,
+    configURL,
     quantized,
     prompt,
     temp,
@@ -66,6 +76,7 @@ async function generate(data) {
      weightsURL,
      modelID,
      tokenizerURL,
+      configURL,
      quantized
    );
 
@@ -26,10 +26,15 @@ pub struct Model {
 #[wasm_bindgen]
 impl Model {
     #[wasm_bindgen(constructor)]
-    pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, quantized: bool) -> Result<Model, JsError> {
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+        quantized: bool,
+    ) -> Result<Model, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let config: Config = Config::v1_5();
+        let config: Config = serde_json::from_slice(&config)?;
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
         let start = Date::now();