mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Polish the llama2 wasm ui. (#232)
* Polish the llama2 wasm ui. * readme update.
This commit is contained in:
@ -24,6 +24,7 @@ serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Wasm specific crates.
|
||||
console_error_panic_hook = "0.1.7"
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
gloo = "0.8"
|
||||
js-sys = "0.3.64"
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::console_log;
|
||||
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
|
||||
use std::str::FromStr;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_futures::JsFuture;
|
||||
use yew::{html, Component, Context, Html};
|
||||
@ -42,6 +43,7 @@ pub struct CurrentDecode {
|
||||
|
||||
pub struct App {
|
||||
status: String,
|
||||
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
|
||||
generated: String,
|
||||
current_decode: Option<CurrentDecode>,
|
||||
worker: Box<dyn Bridge<Worker>>,
|
||||
@ -73,6 +75,7 @@ impl Component for App {
|
||||
let worker = Worker::bridge(std::rc::Rc::new(cb));
|
||||
Self {
|
||||
status,
|
||||
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
|
||||
generated: String::new(),
|
||||
current_decode: None,
|
||||
worker,
|
||||
@ -109,7 +112,10 @@ impl Component for App {
|
||||
self.current_decode = Some(CurrentDecode { start_time });
|
||||
self.status = "generating...".to_string();
|
||||
self.generated.clear();
|
||||
ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run))
|
||||
let temp = *self.temperature.borrow();
|
||||
console_log!("temp: {}", temp);
|
||||
ctx.link()
|
||||
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp)))
|
||||
}
|
||||
true
|
||||
}
|
||||
@ -151,8 +157,16 @@ impl Component for App {
|
||||
}
|
||||
|
||||
fn view(&self, ctx: &Context<Self>) -> Html {
|
||||
use yew::TargetCast;
|
||||
let temperature = self.temperature.clone();
|
||||
let oninput = move |e: yew::InputEvent| {
|
||||
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
|
||||
if let Ok(temp) = f64::from_str(&input.value()) {
|
||||
*temperature.borrow_mut() = temp
|
||||
}
|
||||
};
|
||||
html! {
|
||||
<div>
|
||||
<div style="margin: 2%;">
|
||||
<div><p>{"Running "}
|
||||
<a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a>
|
||||
{" in the browser using rust/wasm with "}
|
||||
@ -161,6 +175,7 @@ impl Component for App {
|
||||
<p>{"Once the weights have loaded, click on the run button to start generating content."}
|
||||
</p>
|
||||
</div>
|
||||
{"temperature: "}<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
|
||||
<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>
|
||||
<br/ >
|
||||
<h3>
|
||||
|
@ -1,4 +1,5 @@
|
||||
fn main() {
|
||||
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
|
||||
console_error_panic_hook::set_once();
|
||||
yew::Renderer::<candle_wasm_example_llama2::App>::new().render();
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
use yew_agent::PublicWorker;
|
||||
fn main() {
|
||||
console_error_panic_hook::set_once();
|
||||
candle_wasm_example_llama2::Worker::register();
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ pub struct Cache {
|
||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||
pub use_kv_cache: bool,
|
||||
#[allow(clippy::type_complexity)]
|
||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
device: Device,
|
||||
|
@ -107,9 +107,11 @@ impl LogitsProcessor {
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> {
|
||||
fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> {
|
||||
let dev = Device::Cpu;
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, None);
|
||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||
console_log!("{temp:?}");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, temp);
|
||||
let mut index_pos = 0;
|
||||
let mut tokens = vec![1u32];
|
||||
|
||||
@ -299,7 +301,7 @@ pub struct Worker {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum WorkerInput {
|
||||
ModelData(ModelData),
|
||||
Run,
|
||||
Run(f64),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@ -332,10 +334,16 @@ impl yew_agent::Worker for Worker {
|
||||
}
|
||||
Err(err) => Err(format!("model creation error {err:?}")),
|
||||
},
|
||||
WorkerInput::Run => match &self.model {
|
||||
WorkerInput::Run(temp) => match &mut self.model {
|
||||
None => Err("model has not been set yet".to_string()),
|
||||
Some(model) => {
|
||||
let result = model.run(&self.link, id).map_err(|e| e.to_string());
|
||||
{
|
||||
let mut cache = model.cache.kvs.lock().unwrap();
|
||||
for elem in cache.iter_mut() {
|
||||
*elem = None
|
||||
}
|
||||
}
|
||||
let result = model.run(&self.link, id, temp).map_err(|e| e.to_string());
|
||||
Ok(WorkerOutput::GenerationDone(result))
|
||||
}
|
||||
},
|
||||
|
Reference in New Issue
Block a user