Implement top_p / nucleus sampling (#819)

* Implement top_p / nucleus sampling

* Update changelog

* rustfmt

* Add tests

* Fix clippy warning

* Fix another clippy error
This commit is contained in:
Juarez Bochi
2023-09-12 09:10:16 -07:00
committed by GitHub
parent 42da17694a
commit 805bf9ffa7
12 changed files with 199 additions and 43 deletions

View File

@ -27,6 +27,10 @@ struct InferenceCmd {
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
#[arg(long, default_value = "")]
prompt: String,
@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
None => {
let cmd = InferenceCmd {
temperature: None,
top_p: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0;
print!("{}", args.prompt);