Implement top_p / nucleus sampling (#819)

* Implement top_p / nucleus sampling

* Update changelog

* rustfmt

* Add tests

* Fix clippy warning

* Fix another clippy error
This commit is contained in:
Juarez Bochi
2023-09-12 09:10:16 -07:00
committed by GitHub
parent 42da17694a
commit 805bf9ffa7
12 changed files with 199 additions and 43 deletions

View File

@@ -47,7 +47,7 @@ impl Model {
tokenizer,
model: weights,
});
-let logits_processor = LogitsProcessor::new(299792458, None);
+let logits_processor = LogitsProcessor::new(299792458, None, None);
match model {
Ok(inner) => Ok(Self {
inner,
@@ -69,6 +69,7 @@ impl Model {
&mut self,
prompt: String,
temp: f64,
+top_p: f64,
repeat_penalty: f32,
seed: u64,
) -> Result<String, JsError> {
@@ -80,7 +81,12 @@ impl Model {
}
}
let temp = if temp <= 0. { None } else { Some(temp) };
-self.logits_processor = LogitsProcessor::new(seed, temp);
+let top_p = if top_p <= 0. || top_p >= 1. {
+None
+} else {
+Some(top_p)
+};
+self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
self.repeat_penalty = repeat_penalty;
self.tokens.clear();
let tokens = self