From 7cef35c84d97b27d4947476847ae0ba590ac9a21 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 31 Aug 2023 18:25:21 +0200 Subject: [PATCH] Tweak some quantized args (#692) * Print the args + change the default temp/repeat penalty. * Minor formatting tweak. --- candle-examples/examples/quantized/main.rs | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index b2aa0cdb..a3f98d8e 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -67,9 +67,9 @@ struct Args { #[arg(long)] tokenizer: Option, - /// The temperature used to generate samples. - #[arg(long)] - temperature: Option, + /// The temperature used to generate samples, use 0 for greedy sampling. + #[arg(long, default_value_t = 0.8)] + temperature: f64, /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] @@ -84,7 +84,7 @@ struct Args { verbose_prompt: bool, /// Penalty to be applied for repeating tokens, 1. means no penalty. - #[arg(long, default_value_t = 1.0)] + #[arg(long, default_value_t = 1.1)] repeat_penalty: f32, /// The context size to consider for the repeat penalty. @@ -188,6 +188,11 @@ fn main() -> anyhow::Result<()> { use tracing_subscriber::prelude::*; let args = Args::parse(); + let temperature = if args.temperature == 0. { + None + } else { + Some(args.temperature) + }; let _guard = if args.tracing { let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); @@ -203,6 +208,10 @@ fn main() -> anyhow::Result<()> { candle::utils::with_simd128(), candle::utils::with_f16c() ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); let model_path = args.model()?; let mut file = std::fs::File::open(&model_path)?; @@ -301,7 +310,7 @@ fn main() -> anyhow::Result<()> { prompt_tokens }; let mut all_tokens = vec![]; - let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); + let mut logits_processor = LogitsProcessor::new(args.seed, temperature); let start_prompt_processing = std::time::Instant::now(); let mut next_token = {