mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling * Update changelog * rustfmt * Add tests * Fix clippy warning * Fix another clippy error
This commit is contained in:
@ -47,7 +47,7 @@ impl Model {
|
||||
tokenizer,
|
||||
model: weights,
|
||||
});
|
||||
let logits_processor = LogitsProcessor::new(299792458, None);
|
||||
let logits_processor = LogitsProcessor::new(299792458, None, None);
|
||||
match model {
|
||||
Ok(inner) => Ok(Self {
|
||||
inner,
|
||||
@ -69,6 +69,7 @@ impl Model {
|
||||
&mut self,
|
||||
prompt: String,
|
||||
temp: f64,
|
||||
top_p: f64,
|
||||
repeat_penalty: f32,
|
||||
seed: u64,
|
||||
) -> Result<String, JsError> {
|
||||
@ -80,7 +81,12 @@ impl Model {
|
||||
}
|
||||
}
|
||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||
self.logits_processor = LogitsProcessor::new(seed, temp);
|
||||
let top_p = if top_p <= 0. || top_p >= 1. {
|
||||
None
|
||||
} else {
|
||||
Some(top_p)
|
||||
};
|
||||
self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
self.repeat_penalty = repeat_penalty;
|
||||
self.tokens.clear();
|
||||
let tokens = self
|
||||
|
Reference in New Issue
Block a user