mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling * Update changelog * rustfmt * Add tests * Fix clippy warning * Fix another clippy error
This commit is contained in:
@ -62,12 +62,18 @@ impl Model {
|
||||
link: &WorkerLink<Worker>,
|
||||
id: HandlerId,
|
||||
temp: f64,
|
||||
top_p: f64,
|
||||
prompt: String,
|
||||
) -> Result<()> {
|
||||
let dev = Device::Cpu;
|
||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||
console_log!("{temp:?} {prompt}");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, temp);
|
||||
let top_p = if top_p <= 0. || top_p >= 1.0 {
|
||||
None
|
||||
} else {
|
||||
Some(top_p)
|
||||
};
|
||||
console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
|
||||
let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
|
||||
let mut index_pos = 0;
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
@ -268,7 +274,7 @@ pub struct Worker {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum WorkerInput {
|
||||
ModelData(ModelData),
|
||||
Run(f64, String),
|
||||
Run(f64, f64, String),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
|
||||
}
|
||||
Err(err) => Err(format!("model creation error {err:?}")),
|
||||
},
|
||||
WorkerInput::Run(temp, prompt) => match &mut self.model {
|
||||
WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
|
||||
None => Err("model has not been set yet".to_string()),
|
||||
Some(model) => {
|
||||
{
|
||||
@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
|
||||
}
|
||||
}
|
||||
let result = model
|
||||
.run(&self.link, id, temp, prompt)
|
||||
.run(&self.link, id, temp, top_p, prompt)
|
||||
.map_err(|e| e.to_string());
|
||||
Ok(WorkerOutput::GenerationDone(result))
|
||||
}
|
||||
|
Reference in New Issue
Block a user