mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Polish the llama2 wasm ui. (#232)
* Polish the llama2 wasm ui. * readme update.
This commit is contained in:
16
README.md
16
README.md
@ -26,6 +26,22 @@ cargo run --example falcon --release
|
|||||||
|
|
||||||
In order to use **CUDA** add `--features cuda` to the example command line.
|
In order to use **CUDA** add `--features cuda` to the example command line.
|
||||||
|
|
||||||
|
There are also some wasm examples for whisper and
|
||||||
|
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
|
||||||
|
`trunk` or try them online:
|
||||||
|
[whisper](https://laurentmazare.github.io/candle-whisper/index.html),
|
||||||
|
[llama2](https://laurentmazare.github.io/candle-llama2/index.html).
|
||||||
|
|
||||||
|
For llama2, run the following command to retrieve the weight files and start a
|
||||||
|
test server:
|
||||||
|
```bash
|
||||||
|
cd candle-wasm-examples/llama2-c
|
||||||
|
wget https://karpathy.ai/llama2c/model.bin
|
||||||
|
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
|
||||||
|
trunk serve --release --public-url /candle-llama2/ --port 8081
|
||||||
|
```
|
||||||
|
And then browse to
|
||||||
|
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ serde = { workspace = true }
|
|||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
|
||||||
# Wasm specific crates.
|
# Wasm specific crates.
|
||||||
|
console_error_panic_hook = "0.1.7"
|
||||||
getrandom = { version = "0.2", features = ["js"] }
|
getrandom = { version = "0.2", features = ["js"] }
|
||||||
gloo = "0.8"
|
gloo = "0.8"
|
||||||
js-sys = "0.3.64"
|
js-sys = "0.3.64"
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use crate::console_log;
|
use crate::console_log;
|
||||||
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
|
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
|
||||||
|
use std::str::FromStr;
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
use wasm_bindgen_futures::JsFuture;
|
use wasm_bindgen_futures::JsFuture;
|
||||||
use yew::{html, Component, Context, Html};
|
use yew::{html, Component, Context, Html};
|
||||||
@ -42,6 +43,7 @@ pub struct CurrentDecode {
|
|||||||
|
|
||||||
pub struct App {
|
pub struct App {
|
||||||
status: String,
|
status: String,
|
||||||
|
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
|
||||||
generated: String,
|
generated: String,
|
||||||
current_decode: Option<CurrentDecode>,
|
current_decode: Option<CurrentDecode>,
|
||||||
worker: Box<dyn Bridge<Worker>>,
|
worker: Box<dyn Bridge<Worker>>,
|
||||||
@ -73,6 +75,7 @@ impl Component for App {
|
|||||||
let worker = Worker::bridge(std::rc::Rc::new(cb));
|
let worker = Worker::bridge(std::rc::Rc::new(cb));
|
||||||
Self {
|
Self {
|
||||||
status,
|
status,
|
||||||
|
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
|
||||||
generated: String::new(),
|
generated: String::new(),
|
||||||
current_decode: None,
|
current_decode: None,
|
||||||
worker,
|
worker,
|
||||||
@ -109,7 +112,10 @@ impl Component for App {
|
|||||||
self.current_decode = Some(CurrentDecode { start_time });
|
self.current_decode = Some(CurrentDecode { start_time });
|
||||||
self.status = "generating...".to_string();
|
self.status = "generating...".to_string();
|
||||||
self.generated.clear();
|
self.generated.clear();
|
||||||
ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run))
|
let temp = *self.temperature.borrow();
|
||||||
|
console_log!("temp: {}", temp);
|
||||||
|
ctx.link()
|
||||||
|
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp)))
|
||||||
}
|
}
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
@ -151,8 +157,16 @@ impl Component for App {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn view(&self, ctx: &Context<Self>) -> Html {
|
fn view(&self, ctx: &Context<Self>) -> Html {
|
||||||
|
use yew::TargetCast;
|
||||||
|
let temperature = self.temperature.clone();
|
||||||
|
let oninput = move |e: yew::InputEvent| {
|
||||||
|
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
|
||||||
|
if let Ok(temp) = f64::from_str(&input.value()) {
|
||||||
|
*temperature.borrow_mut() = temp
|
||||||
|
}
|
||||||
|
};
|
||||||
html! {
|
html! {
|
||||||
<div>
|
<div style="margin: 2%;">
|
||||||
<div><p>{"Running "}
|
<div><p>{"Running "}
|
||||||
<a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a>
|
<a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a>
|
||||||
{" in the browser using rust/wasm with "}
|
{" in the browser using rust/wasm with "}
|
||||||
@ -161,6 +175,7 @@ impl Component for App {
|
|||||||
<p>{"Once the weights have loaded, click on the run button to start generating content."}
|
<p>{"Once the weights have loaded, click on the run button to start generating content."}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
{"temperature: "}<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
|
||||||
<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>
|
<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>
|
||||||
<br/ >
|
<br/ >
|
||||||
<h3>
|
<h3>
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
fn main() {
|
fn main() {
|
||||||
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
|
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
|
||||||
|
console_error_panic_hook::set_once();
|
||||||
yew::Renderer::<candle_wasm_example_llama2::App>::new().render();
|
yew::Renderer::<candle_wasm_example_llama2::App>::new().render();
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
use yew_agent::PublicWorker;
|
use yew_agent::PublicWorker;
|
||||||
fn main() {
|
fn main() {
|
||||||
|
console_error_panic_hook::set_once();
|
||||||
candle_wasm_example_llama2::Worker::register();
|
candle_wasm_example_llama2::Worker::register();
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ pub struct Cache {
|
|||||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||||
pub use_kv_cache: bool,
|
pub use_kv_cache: bool,
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
device: Device,
|
device: Device,
|
||||||
|
@ -107,9 +107,11 @@ impl LogitsProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> {
|
fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> {
|
||||||
let dev = Device::Cpu;
|
let dev = Device::Cpu;
|
||||||
let mut logits_processor = LogitsProcessor::new(299792458, None);
|
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||||
|
console_log!("{temp:?}");
|
||||||
|
let mut logits_processor = LogitsProcessor::new(299792458, temp);
|
||||||
let mut index_pos = 0;
|
let mut index_pos = 0;
|
||||||
let mut tokens = vec![1u32];
|
let mut tokens = vec![1u32];
|
||||||
|
|
||||||
@ -299,7 +301,7 @@ pub struct Worker {
|
|||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
pub enum WorkerInput {
|
pub enum WorkerInput {
|
||||||
ModelData(ModelData),
|
ModelData(ModelData),
|
||||||
Run,
|
Run(f64),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
@ -332,10 +334,16 @@ impl yew_agent::Worker for Worker {
|
|||||||
}
|
}
|
||||||
Err(err) => Err(format!("model creation error {err:?}")),
|
Err(err) => Err(format!("model creation error {err:?}")),
|
||||||
},
|
},
|
||||||
WorkerInput::Run => match &self.model {
|
WorkerInput::Run(temp) => match &mut self.model {
|
||||||
None => Err("model has not been set yet".to_string()),
|
None => Err("model has not been set yet".to_string()),
|
||||||
Some(model) => {
|
Some(model) => {
|
||||||
let result = model.run(&self.link, id).map_err(|e| e.to_string());
|
{
|
||||||
|
let mut cache = model.cache.kvs.lock().unwrap();
|
||||||
|
for elem in cache.iter_mut() {
|
||||||
|
*elem = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let result = model.run(&self.link, id, temp).map_err(|e| e.to_string());
|
||||||
Ok(WorkerOutput::GenerationDone(result))
|
Ok(WorkerOutput::GenerationDone(result))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
Reference in New Issue
Block a user