From 7c65e2d187fd86307606191b9a418cb925c42335 Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 1 Jul 2023 06:36:22 +0100 Subject: [PATCH] Add a flag for custom prompt. --- candle-core/examples/llama/main.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs index 72603bf7..73db15e0 100644 --- a/candle-core/examples/llama/main.rs +++ b/candle-core/examples/llama/main.rs @@ -25,7 +25,7 @@ mod weights; const MAX_SEQ_LEN: usize = 4096; const DTYPE: DType = DType::F16; -const START_PROMPT: &str = r" +const DEFAULT_PROMPT: &str = r" EDWARD: I wonder how our princely father 'scaped, Or whether he be 'scaped away or no @@ -455,6 +455,10 @@ struct Args { /// Disable the key-value cache. #[arg(long)] no_kv_cache: bool, + + /// The initial prompt. + #[arg(long)] + prompt: Option<String>, } #[tokio::main] async fn main() -> Result<()> { }; println!("Loaded in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); let mut tokens = tokenizer - .encode(START_PROMPT, true) + .encode(prompt, true) .map_err(E::msg)? .get_ids() .to_vec();