Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
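This change threads a top_p option through the example's CLI and into LogitsProcessor. For context, nucleus (top-p) sampling keeps only the smallest set of highest-probability tokens whose cumulative mass exceeds top_p, renormalizes within that set, and samples from it. The following is a minimal, self-contained sketch of the idea, not the actual implementation inside candle's LogitsProcessor; it takes a pre-drawn uniform value u so it stays free of external crates.

/// Sketch of nucleus (top-p) sampling: keep the smallest set of tokens whose
/// cumulative probability exceeds `top_p`, renormalize, and sample within it.
/// `u` is a pre-drawn uniform value in [0, 1).
fn sample_top_p(probs: &[f32], top_p: f32, u: f32) -> usize {
    // Sort token indices by probability, highest first.
    let mut indices: Vec<usize> = (0..probs.len()).collect();
    indices.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));

    // The nucleus is the shortest prefix whose cumulative mass exceeds top_p.
    let mut cumulative = 0.0f32;
    let mut cutoff = indices.len();
    for (i, &idx) in indices.iter().enumerate() {
        cumulative += probs[idx];
        if cumulative > top_p {
            cutoff = i + 1;
            break;
        }
    }
    let nucleus = &indices[..cutoff];

    // Renormalize within the nucleus and draw from it with the uniform value.
    let mass: f32 = nucleus.iter().map(|&idx| probs[idx]).sum();
    let mut acc = 0.0f32;
    for &idx in nucleus {
        acc += probs[idx] / mass;
        if u < acc {
            return idx;
        }
    }
    *nucleus.last().unwrap() // guard against floating-point round-off
}

fn main() {
    let probs = [0.5, 0.3, 0.1, 0.05, 0.05];
    // With top_p = 0.85 the nucleus is tokens {0, 1, 2}; u = 0.6 picks token 1.
    println!("sampled token: {}", sample_top_p(&probs, 0.85, 0.6));
}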
@@ -27,6 +27,10 @@ struct InferenceCmd {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     #[arg(long, default_value = "")]
     prompt: String,
 
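With the clap derive attribute above, the new field should surface as a --top-p command-line flag (clap's default kebab-case renaming), alongside the existing --temperature option; values near 1.0 keep most of the distribution, while smaller values restrict sampling to the most likely tokens.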
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
         None => {
             let cmd = InferenceCmd {
                 temperature: None,
+                top_p: None,
                 prompt: "".to_string(),
                 config: None,
                 model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let model = Llama::load(vb, &cache, config)?;
 
     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
     let mut index_pos = 0;
 
     print!("{}", args.prompt);