mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add an option to split the prompt. (#1766)
This commit is contained in:
@ -212,6 +212,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
verbose_prompt: bool,
|
verbose_prompt: bool,
|
||||||
|
|
||||||
|
/// Process prompt elements separately.
|
||||||
|
#[arg(long)]
|
||||||
|
split_prompt: bool,
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
#[arg(long, default_value_t = 1.1)]
|
#[arg(long, default_value_t = 1.1)]
|
||||||
repeat_penalty: f32,
|
repeat_penalty: f32,
|
||||||
@ -487,11 +491,20 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
|
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
|
||||||
|
|
||||||
let start_prompt_processing = std::time::Instant::now();
|
let start_prompt_processing = std::time::Instant::now();
|
||||||
let mut next_token = {
|
let mut next_token = if !args.split_prompt {
|
||||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||||
let logits = model.forward(&input, 0)?;
|
let logits = model.forward(&input, 0)?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
logits_processor.sample(&logits)?
|
logits_processor.sample(&logits)?
|
||||||
|
} else {
|
||||||
|
let mut next_token = 0;
|
||||||
|
for (pos, token) in prompt_tokens.iter().enumerate() {
|
||||||
|
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, pos)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
next_token = logits_processor.sample(&logits)?
|
||||||
|
}
|
||||||
|
next_token
|
||||||
};
|
};
|
||||||
let prompt_dt = start_prompt_processing.elapsed();
|
let prompt_dt = start_prompt_processing.elapsed();
|
||||||
all_tokens.push(next_token);
|
all_tokens.push(next_token);
|
||||||
|
Reference in New Issue
Block a user