mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add a `--prompt` flag for supplying a custom initial prompt (the built-in default prompt is used when the flag is omitted).
This commit is contained in:
@ -25,7 +25,7 @@ mod weights;
|
|||||||
|
|
||||||
const MAX_SEQ_LEN: usize = 4096;
|
const MAX_SEQ_LEN: usize = 4096;
|
||||||
const DTYPE: DType = DType::F16;
|
const DTYPE: DType = DType::F16;
|
||||||
const START_PROMPT: &str = r"
|
const DEFAULT_PROMPT: &str = r"
|
||||||
EDWARD:
|
EDWARD:
|
||||||
I wonder how our princely father 'scaped,
|
I wonder how our princely father 'scaped,
|
||||||
Or whether he be 'scaped away or no
|
Or whether he be 'scaped away or no
|
||||||
@ -455,6 +455,10 @@ struct Args {
|
|||||||
/// Disable the key-value cache.
|
/// Disable the key-value cache.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
no_kv_cache: bool,
|
no_kv_cache: bool,
|
||||||
|
|
||||||
|
/// The initial prompt.
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@ -500,8 +504,9 @@ async fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
println!("Loaded in {:?}", start.elapsed());
|
println!("Loaded in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||||
let mut tokens = tokenizer
|
let mut tokens = tokenizer
|
||||||
.encode(START_PROMPT, true)
|
.encode(prompt, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
Reference in New Issue
Block a user