Bugfix for the llama2 wasm example. (#310)

* Clean up the llama2.c wasm example.

* Use a proper tokenizer.

* Add a prompt.

* Bugfix for the llama2 wasm example.
This commit is contained in:
Laurent Mazare
2023-08-02 17:32:36 +01:00
committed by GitHub
parent 186c308d51
commit 52414ba5c8
2 changed files with 37 additions and 9 deletions

View File

@@ -46,6 +46,7 @@ pub struct App {
status: String, status: String,
loaded: bool, loaded: bool,
temperature: std::rc::Rc<std::cell::RefCell<f64>>, temperature: std::rc::Rc<std::cell::RefCell<f64>>,
prompt: std::rc::Rc<std::cell::RefCell<String>>,
generated: String, generated: String,
n_tokens: usize, n_tokens: usize,
current_decode: Option<CurrentDecode>, current_decode: Option<CurrentDecode>,
@@ -80,6 +81,7 @@ impl Component for App {
status, status,
n_tokens: 0, n_tokens: 0,
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)), temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
generated: String::new(), generated: String::new(),
current_decode: None, current_decode: None,
worker, worker,
@@ -120,9 +122,10 @@ impl Component for App {
self.n_tokens = 0; self.n_tokens = 0;
self.generated.clear(); self.generated.clear();
let temp = *self.temperature.borrow(); let temp = *self.temperature.borrow();
console_log!("temp: {}", temp); let prompt = self.prompt.borrow().clone();
console_log!("temp: {}, prompt: {}", temp, prompt);
ctx.link() ctx.link()
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp))) .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
} }
true true
} }
@@ -181,6 +184,12 @@ impl Component for App {
} }
Msg::Refresh Msg::Refresh
}); });
let prompt = self.prompt.clone();
let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
*prompt.borrow_mut() = input.value();
Msg::Refresh
});
html! { html! {
<div style="margin: 2%;"> <div style="margin: 2%;">
<div><p>{"Running "} <div><p>{"Running "}
@@ -195,6 +204,8 @@ impl Component for App {
<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/> <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
{format!(" \u{00a0} {}", self.temperature.borrow())} {format!(" \u{00a0} {}", self.temperature.borrow())}
<br/ > <br/ >
{"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
<br/ >
{ {
if self.loaded{ if self.loaded{
html!(<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>) html!(<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>)

View File

@@ -91,15 +91,30 @@ impl LogitsProcessor {
} }
impl Model { impl Model {
fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> { fn run(
&self,
link: &WorkerLink<Worker>,
id: HandlerId,
temp: f64,
prompt: String,
) -> Result<()> {
let dev = Device::Cpu; let dev = Device::Cpu;
let temp = if temp <= 0. { None } else { Some(temp) }; let temp = if temp <= 0. { None } else { Some(temp) };
console_log!("{temp:?}"); console_log!("{temp:?} {prompt}");
let mut logits_processor = LogitsProcessor::new(299792458, temp); let mut logits_processor = LogitsProcessor::new(299792458, temp);
let mut index_pos = 0; let mut index_pos = 0;
let mut tokens = vec![1u32]; let mut tokens = self
.tokenizer
.encode(prompt.to_string(), true)
.map_err(|m| candle::Error::Msg(m.to_string()))?
.get_ids()
.to_vec();
link.respond(id, Ok(WorkerOutput::Generated(prompt)));
for index in 0..self.config.seq_len - 10 { for index in 0.. {
if tokens.len() >= self.config.seq_len {
break;
}
let context_size = if self.cache.use_kv_cache && index > 0 { let context_size = if self.cache.use_kv_cache && index > 0 {
1 1
} else { } else {
@@ -287,7 +302,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub enum WorkerInput { pub enum WorkerInput {
ModelData(ModelData), ModelData(ModelData),
Run(f64), Run(f64, String),
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@@ -320,7 +335,7 @@ impl yew_agent::Worker for Worker {
} }
Err(err) => Err(format!("model creation error {err:?}")), Err(err) => Err(format!("model creation error {err:?}")),
}, },
WorkerInput::Run(temp) => match &mut self.model { WorkerInput::Run(temp, prompt) => match &mut self.model {
None => Err("model has not been set yet".to_string()), None => Err("model has not been set yet".to_string()),
Some(model) => { Some(model) => {
{ {
@@ -329,7 +344,9 @@ impl yew_agent::Worker for Worker {
*elem = None *elem = None
} }
} }
let result = model.run(&self.link, id, temp).map_err(|e| e.to_string()); let result = model
.run(&self.link, id, temp, prompt)
.map_err(|e| e.to_string());
Ok(WorkerOutput::GenerationDone(result)) Ok(WorkerOutput::GenerationDone(result))
} }
}, },