Implement top_p / nucleus sampling (#819)

* Implement top_p / nucleus sampling

* Update changelog

* rustfmt

* Add tests

* Fix clippy warning

* Fix another clippy error
This commit is contained in:
Juarez Bochi
2023-09-12 09:10:16 -07:00
committed by GitHub
parent 42da17694a
commit 805bf9ffa7
12 changed files with 199 additions and 43 deletions

View File

@ -46,6 +46,7 @@ pub struct App {
status: String,
loaded: bool,
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
top_p: std::rc::Rc<std::cell::RefCell<f64>>,
prompt: std::rc::Rc<std::cell::RefCell<String>>,
generated: String,
n_tokens: usize,
@ -81,6 +82,7 @@ impl Component for App {
status,
n_tokens: 0,
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
generated: String::new(),
current_decode: None,
@ -122,10 +124,11 @@ impl Component for App {
self.n_tokens = 0;
self.generated.clear();
let temp = *self.temperature.borrow();
let top_p = *self.top_p.borrow();
let prompt = self.prompt.borrow().clone();
console_log!("temp: {}, prompt: {}", temp, prompt);
console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
ctx.link()
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
}
true
}
@ -177,13 +180,21 @@ impl Component for App {
fn view(&self, ctx: &Context<Self>) -> Html {
use yew::TargetCast;
let temperature = self.temperature.clone();
let oninput = ctx.link().callback(move |e: yew::InputEvent| {
let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
if let Ok(temp) = f64::from_str(&input.value()) {
*temperature.borrow_mut() = temp
}
Msg::Refresh
});
let top_p = self.top_p.clone();
let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
if let Ok(top_p_input) = f64::from_str(&input.value()) {
*top_p.borrow_mut() = top_p_input
}
Msg::Refresh
});
let prompt = self.prompt.clone();
let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
@ -201,9 +212,13 @@ impl Component for App {
</p>
</div>
{"temperature \u{00a0} "}
<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/>
{format!(" \u{00a0} {}", self.temperature.borrow())}
<br/ >
{"top_p \u{00a0} "}
<input type="range" min="0." max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/>
{format!(" \u{00a0} {}", self.top_p.borrow())}
<br/ >
{"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
<br/ >
{