Polish the llama2 wasm ui. (#232)

* Polish the llama2 wasm ui.

* readme update.
This commit is contained in:
Laurent Mazare
2023-07-24 15:28:27 +01:00
committed by GitHub
parent 5a26cba733
commit 160ba09d30
7 changed files with 50 additions and 8 deletions

View File

@ -26,6 +26,22 @@ cargo run --example falcon --release
In order to use **CUDA** add `--features cuda` to the example command line.
There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online:
[whisper](https://laurentmazare.github.io/candle-whisper/index.html),
[llama2](https://laurentmazare.github.io/candle-llama2/index.html).
For llama2, run the following commands to retrieve the weight files and start a
test server:
```bash
cd candle-wasm-examples/llama2-c
wget https://karpathy.ai/llama2c/model.bin
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
trunk serve --release --public-url /candle-llama2/ --port 8081
```
And then browse to
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
## Features

View File

@ -24,6 +24,7 @@ serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
# Wasm specific crates. # Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] } getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8" gloo = "0.8"
js-sys = "0.3.64" js-sys = "0.3.64"

View File

@ -1,5 +1,6 @@
use crate::console_log; use crate::console_log;
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
use std::str::FromStr;
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture; use wasm_bindgen_futures::JsFuture;
use yew::{html, Component, Context, Html}; use yew::{html, Component, Context, Html};
@ -42,6 +43,7 @@ pub struct CurrentDecode {
pub struct App { pub struct App {
status: String, status: String,
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
generated: String, generated: String,
current_decode: Option<CurrentDecode>, current_decode: Option<CurrentDecode>,
worker: Box<dyn Bridge<Worker>>, worker: Box<dyn Bridge<Worker>>,
@ -73,6 +75,7 @@ impl Component for App {
let worker = Worker::bridge(std::rc::Rc::new(cb)); let worker = Worker::bridge(std::rc::Rc::new(cb));
Self { Self {
status, status,
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
generated: String::new(), generated: String::new(),
current_decode: None, current_decode: None,
worker, worker,
@ -109,7 +112,10 @@ impl Component for App {
self.current_decode = Some(CurrentDecode { start_time }); self.current_decode = Some(CurrentDecode { start_time });
self.status = "generating...".to_string(); self.status = "generating...".to_string();
self.generated.clear(); self.generated.clear();
ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run)) let temp = *self.temperature.borrow();
console_log!("temp: {}", temp);
ctx.link()
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp)))
} }
true true
} }
@ -151,8 +157,16 @@ impl Component for App {
} }
fn view(&self, ctx: &Context<Self>) -> Html { fn view(&self, ctx: &Context<Self>) -> Html {
use yew::TargetCast;
let temperature = self.temperature.clone();
let oninput = move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
if let Ok(temp) = f64::from_str(&input.value()) {
*temperature.borrow_mut() = temp
}
};
html! { html! {
<div> <div style="margin: 2%;">
<div><p>{"Running "} <div><p>{"Running "}
<a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a> <a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a>
{" in the browser using rust/wasm with "} {" in the browser using rust/wasm with "}
@ -161,6 +175,7 @@ impl Component for App {
<p>{"Once the weights have loaded, click on the run button to start generating content."} <p>{"Once the weights have loaded, click on the run button to start generating content."}
</p> </p>
</div> </div>
{"temperature: "}<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button> <button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>
<br/ > <br/ >
<h3> <h3>

View File

@ -1,4 +1,5 @@
fn main() { fn main() {
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace)); wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
console_error_panic_hook::set_once();
yew::Renderer::<candle_wasm_example_llama2::App>::new().render(); yew::Renderer::<candle_wasm_example_llama2::App>::new().render();
} }

View File

@ -1,4 +1,5 @@
use yew_agent::PublicWorker; use yew_agent::PublicWorker;
fn main() { fn main() {
console_error_panic_hook::set_once();
candle_wasm_example_llama2::Worker::register(); candle_wasm_example_llama2::Worker::register();
} }

View File

@ -20,7 +20,7 @@ pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>, masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool, pub use_kv_cache: bool,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor, cos: Tensor,
sin: Tensor, sin: Tensor,
device: Device, device: Device,

View File

@ -107,9 +107,11 @@ impl LogitsProcessor {
} }
impl Model { impl Model {
fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> { fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> {
let dev = Device::Cpu; let dev = Device::Cpu;
let mut logits_processor = LogitsProcessor::new(299792458, None); let temp = if temp <= 0. { None } else { Some(temp) };
console_log!("{temp:?}");
let mut logits_processor = LogitsProcessor::new(299792458, temp);
let mut index_pos = 0; let mut index_pos = 0;
let mut tokens = vec![1u32]; let mut tokens = vec![1u32];
@ -299,7 +301,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub enum WorkerInput { pub enum WorkerInput {
ModelData(ModelData), ModelData(ModelData),
Run, Run(f64),
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -332,10 +334,16 @@ impl yew_agent::Worker for Worker {
} }
Err(err) => Err(format!("model creation error {err:?}")), Err(err) => Err(format!("model creation error {err:?}")),
}, },
WorkerInput::Run => match &self.model { WorkerInput::Run(temp) => match &mut self.model {
None => Err("model has not been set yet".to_string()), None => Err("model has not been set yet".to_string()),
Some(model) => { Some(model) => {
let result = model.run(&self.link, id).map_err(|e| e.to_string()); {
let mut cache = model.cache.kvs.lock().unwrap();
for elem in cache.iter_mut() {
*elem = None
}
}
let result = model.run(&self.link, id, temp).map_err(|e| e.to_string());
Ok(WorkerOutput::GenerationDone(result)) Ok(WorkerOutput::GenerationDone(result))
} }
}, },