Add a flag for a custom prompt.

This commit is contained in:
laurent
2023-07-01 06:36:22 +01:00
parent 2c04bff12f
commit 7c65e2d187

View File

@@ -25,7 +25,7 @@ mod weights;
const MAX_SEQ_LEN: usize = 4096; const MAX_SEQ_LEN: usize = 4096;
const DTYPE: DType = DType::F16; const DTYPE: DType = DType::F16;
const START_PROMPT: &str = r" const DEFAULT_PROMPT: &str = r"
EDWARD: EDWARD:
I wonder how our princely father 'scaped, I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no Or whether he be 'scaped away or no
@@ -455,6 +455,10 @@ struct Args {
/// Disable the key-value cache. /// Disable the key-value cache.
#[arg(long)] #[arg(long)]
no_kv_cache: bool, no_kv_cache: bool,
/// The initial prompt.
#[arg(long)]
prompt: Option<String>,
} }
#[tokio::main] #[tokio::main]
@@ -500,8 +504,9 @@ async fn main() -> Result<()> {
}; };
println!("Loaded in {:?}", start.elapsed()); println!("Loaded in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer let mut tokens = tokenizer
.encode(START_PROMPT, true) .encode(prompt, true)
.map_err(E::msg)? .map_err(E::msg)?
.get_ids() .get_ids()
.to_vec(); .to_vec();