mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling * Update changelog * rustfmt * Add tests * Fix clippy warning * Fix another clippy error
This commit is contained in:
@ -56,6 +56,7 @@
|
||||
const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
|
||||
const prompt = getValue("prompt");
|
||||
const temperature = getValue("temperature");
|
||||
const topP = getValue("top-p");
|
||||
const repeatPenalty = getValue("repeat_penalty");
|
||||
const seed = getValue("seed");
|
||||
const maxSeqLen = getValue("max-seq");
|
||||
@ -99,6 +100,7 @@
|
||||
tokenizerURL: "tokenizer.json",
|
||||
prompt,
|
||||
temp: temperature,
|
||||
top_p: topP,
|
||||
repeatPenalty,
|
||||
seed: BigInt(seed),
|
||||
maxSeqLen,
|
||||
@ -251,7 +253,7 @@
|
||||
<input
|
||||
type="range"
|
||||
id="max-seq"
|
||||
name="temperature"
|
||||
name="max-seq"
|
||||
min="1"
|
||||
max="256"
|
||||
step="1"
|
||||
@ -279,6 +281,22 @@
|
||||
>
|
||||
0.50</output
|
||||
>
|
||||
<label class="text-sm font-medium" for="top-p">Top-p</label>
|
||||
<input
|
||||
type="range"
|
||||
id="top-p"
|
||||
name="top-p"
|
||||
min="0"
|
||||
max="1"
|
||||
step="0.01"
|
||||
value="1.00"
|
||||
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
|
||||
/>
|
||||
<output
|
||||
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
|
||||
>
|
||||
1.00</output
|
||||
>
|
||||
|
||||
<label class="text-sm font-medium" for="repeat_penalty"
|
||||
>Repeat Penalty</label
|
||||
|
@ -46,6 +46,7 @@ pub struct App {
|
||||
status: String,
|
||||
loaded: bool,
|
||||
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
|
||||
top_p: std::rc::Rc<std::cell::RefCell<f64>>,
|
||||
prompt: std::rc::Rc<std::cell::RefCell<String>>,
|
||||
generated: String,
|
||||
n_tokens: usize,
|
||||
@ -81,6 +82,7 @@ impl Component for App {
|
||||
status,
|
||||
n_tokens: 0,
|
||||
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
|
||||
top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
|
||||
prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
|
||||
generated: String::new(),
|
||||
current_decode: None,
|
||||
@ -122,10 +124,11 @@ impl Component for App {
|
||||
self.n_tokens = 0;
|
||||
self.generated.clear();
|
||||
let temp = *self.temperature.borrow();
|
||||
let top_p = *self.top_p.borrow();
|
||||
let prompt = self.prompt.borrow().clone();
|
||||
console_log!("temp: {}, prompt: {}", temp, prompt);
|
||||
console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
|
||||
ctx.link()
|
||||
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
|
||||
.send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
|
||||
}
|
||||
true
|
||||
}
|
||||
@ -177,13 +180,21 @@ impl Component for App {
|
||||
fn view(&self, ctx: &Context<Self>) -> Html {
|
||||
use yew::TargetCast;
|
||||
let temperature = self.temperature.clone();
|
||||
let oninput = ctx.link().callback(move |e: yew::InputEvent| {
|
||||
let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
|
||||
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
|
||||
if let Ok(temp) = f64::from_str(&input.value()) {
|
||||
*temperature.borrow_mut() = temp
|
||||
}
|
||||
Msg::Refresh
|
||||
});
|
||||
let top_p = self.top_p.clone();
|
||||
let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
|
||||
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
|
||||
if let Ok(top_p_input) = f64::from_str(&input.value()) {
|
||||
*top_p.borrow_mut() = top_p_input
|
||||
}
|
||||
Msg::Refresh
|
||||
});
|
||||
let prompt = self.prompt.clone();
|
||||
let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
|
||||
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
|
||||
@ -201,9 +212,13 @@ impl Component for App {
|
||||
</p>
|
||||
</div>
|
||||
{"temperature \u{00a0} "}
|
||||
<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
|
||||
<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} oninput={oninput_temperature} id="temp"/>
|
||||
{format!(" \u{00a0} {}", self.temperature.borrow())}
|
||||
<br/ >
|
||||
{"top_p \u{00a0} "}
|
||||
<input type="range" min="0." max="1.0" step="0.05" value={self.top_p.borrow().to_string()} oninput={oninput_top_p} id="top_p"/>
|
||||
{format!(" \u{00a0} {}", self.top_p.borrow())}
|
||||
<br/ >
|
||||
{"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
|
||||
<br/ >
|
||||
{
|
||||
|
@ -47,7 +47,7 @@ impl Model {
|
||||
tokenizer,
|
||||
model: weights,
|
||||
});
|
||||
let logits_processor = LogitsProcessor::new(299792458, None);
|
||||
let logits_processor = LogitsProcessor::new(299792458, None, None);
|
||||
match model {
|
||||
Ok(inner) => Ok(Self {
|
||||
inner,
|
||||
@ -69,6 +69,7 @@ impl Model {
|
||||
&mut self,
|
||||
prompt: String,
|
||||
temp: f64,
|
||||
top_p: f64,
|
||||
repeat_penalty: f32,
|
||||
seed: u64,
|
||||
) -> Result<String, JsError> {
|
||||
@ -80,7 +81,12 @@ impl Model {
|
||||
}
|
||||
}
|
||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||
self.logits_processor = LogitsProcessor::new(seed, temp);
|
||||
let top_p = if top_p <= 0. || top_p >= 1. {
|
||||
None
|
||||
} else {
|
||||
Some(top_p)
|
||||
};
|
||||
self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
self.repeat_penalty = repeat_penalty;
|
||||
self.tokens.clear();
|
||||
let tokens = self
|
||||
|
@ -62,12 +62,18 @@ impl Model {
|
||||
link: &WorkerLink<Worker>,
|
||||
id: HandlerId,
|
||||
temp: f64,
|
||||
top_p: f64,
|
||||
prompt: String,
|
||||
) -> Result<()> {
|
||||
let dev = Device::Cpu;
|
||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||
console_log!("{temp:?} {prompt}");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, temp);
|
||||
let top_p = if top_p <= 0. || top_p >= 1.0 {
|
||||
None
|
||||
} else {
|
||||
Some(top_p)
|
||||
};
|
||||
console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
|
||||
let mut index_pos = 0;
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
@ -268,7 +274,7 @@ pub struct Worker {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum WorkerInput {
|
||||
ModelData(ModelData),
|
||||
Run(f64, String),
|
||||
Run(f64, f64, String),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
|
||||
}
|
||||
Err(err) => Err(format!("model creation error {err:?}")),
|
||||
},
|
||||
WorkerInput::Run(temp, prompt) => match &mut self.model {
|
||||
WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
|
||||
None => Err("model has not been set yet".to_string()),
|
||||
Some(model) => {
|
||||
{
|
||||
@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
|
||||
}
|
||||
}
|
||||
let result = model
|
||||
.run(&self.link, id, temp, prompt)
|
||||
.run(&self.link, id, temp, top_p, prompt)
|
||||
.map_err(|e| e.to_string());
|
||||
Ok(WorkerOutput::GenerationDone(result))
|
||||
}
|
||||
|
Reference in New Issue
Block a user