Mirror of https://github.com/huggingface/candle.git
Interactive mode for the quantized model. (#690)

candle-examples/examples/quantized/main.rs

@@ -17,6 +17,13 @@ use model::ModelWeights;
 
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
+#[derive(Debug)]
+enum Prompt {
+    Interactive,
+    Chat,
+    One(String),
+}
+
 #[derive(Clone, Debug, Copy, ValueEnum)]
 enum Which {
     #[value(name = "7b")]

@@ -46,7 +53,9 @@ struct Args {
     #[arg(long)]
     model: Option<String>,
 
-    /// The initial prompt.
+    /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
+    /// and 'chat' for an interactive mode where the history of previous prompts and generated tokens
+    /// is preserved.
     #[arg(long)]
     prompt: Option<String>,
 

@@ -247,8 +256,35 @@ fn main() -> anyhow::Result<()> {
     println!("model built");
 
     let tokenizer = args.tokenizer()?;
-    let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
-    let tokens = tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?;
+    let prompt = match args.prompt.as_deref() {
+        Some("chat") => Prompt::Chat,
+        Some("interactive") => Prompt::Interactive,
+        Some(s) => Prompt::One(s.to_string()),
+        None => Prompt::One(DEFAULT_PROMPT.to_string()),
+    };
+
+    let mut pre_prompt_tokens = vec![];
+    loop {
+        let prompt_str = match &prompt {
+            Prompt::One(prompt) => prompt.clone(),
+            Prompt::Interactive | Prompt::Chat => {
+                print!("> ");
+                std::io::stdout().flush()?;
+                let mut prompt = String::new();
+                std::io::stdin().read_line(&mut prompt)?;
+                if prompt.ends_with('\n') {
+                    prompt.pop();
+                    if prompt.ends_with('\r') {
+                        prompt.pop();
+                    }
+                }
+                prompt
+            }
+        };
+        print!("{}", &prompt_str);
+        let tokens = tokenizer
+            .encode(prompt_str, true)
+            .map_err(anyhow::Error::msg)?;
         if args.verbose_prompt {
             for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                 let token = token.replace('▁', " ").replace("<0x0A>", "\n");
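
The interactive branch above reads the prompt from stdin. The explicit flush matters because print! does not append a newline, so the "> " marker would otherwise sit in the stdout buffer while the program blocks on read_line; the two pop calls strip the trailing '\n' (and the '\r' that precedes it on Windows). A minimal standalone sketch of that input step, using a hypothetical read_prompt helper that is not part of the diff:

use std::io::Write;

// Print a "> " marker, read one line from stdin, and strip the line
// terminator ('\n', plus '\r' on Windows), as the loop in the diff does.
fn read_prompt() -> std::io::Result<String> {
    print!("> ");
    // print! does not flush, so force the marker out before blocking on stdin.
    std::io::stdout().flush()?;
    let mut prompt = String::new();
    std::io::stdin().read_line(&mut prompt)?;
    if prompt.ends_with('\n') {
        prompt.pop();
        if prompt.ends_with('\r') {
            prompt.pop();
        }
    }
    Ok(prompt)
}

fn main() -> std::io::Result<()> {
    let prompt = read_prompt()?;
    println!("you entered: {prompt:?}");
    Ok(())
}
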

@@ -256,12 +292,17 @@ fn main() -> anyhow::Result<()> {
             }
         }
 
-        let prompt_tokens = tokens.get_ids().to_vec();
+        let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
+        let to_sample = args.sample_len.saturating_sub(1);
+        let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
+            let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
+            prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
+        } else {
+            prompt_tokens
+        };
         let mut all_tokens = vec![];
         let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
 
-        print!("{prompt}");
-
         let start_prompt_processing = std::time::Instant::now();
         let mut next_token = {
             let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
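
The new prompt_tokens binding prepends the accumulated chat history and then clips it so that the prompt plus the requested number of samples stays at least 10 tokens under the model's context window. A minimal sketch of that arithmetic, pulled out into a hypothetical clamp_to_context helper (the diff inlines this directly in main; the 4096 and the margin of 10 mirror model::MAX_SEQ_LEN and the values used above):

const MAX_SEQ_LEN: usize = 4096; // mirrors model::MAX_SEQ_LEN in the diff

// Drop the oldest tokens so that tokens.len() + to_sample stays at least
// 10 tokens below MAX_SEQ_LEN, matching the margin used in the diff.
fn clamp_to_context(tokens: Vec<u32>, to_sample: usize) -> Vec<u32> {
    if tokens.len() + to_sample > MAX_SEQ_LEN - 10 {
        let to_remove = tokens.len() + to_sample + 10 - MAX_SEQ_LEN;
        tokens[tokens.len().saturating_sub(to_remove)..].to_vec()
    } else {
        tokens
    }
}

fn main() {
    // 4000 history tokens plus 200 to sample exceeds 4096 - 10, so the oldest
    // 4000 + 200 + 10 - 4096 = 114 tokens are dropped and 3886 are kept.
    let history: Vec<u32> = (0..4000).collect();
    let clamped = clamp_to_context(history, 200);
    assert_eq!(clamped.len(), 3886);
    println!("kept {} of 4000 history tokens", clamped.len());
}
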

@@ -273,7 +314,6 @@ fn main() -> anyhow::Result<()> {
         all_tokens.push(next_token);
         print_token(next_token, &tokenizer);
 
-        let to_sample = args.sample_len.saturating_sub(1);
         let start_post_prompt = std::time::Instant::now();
         for index in 0..to_sample {
             let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;

@@ -304,5 +344,15 @@ fn main() -> anyhow::Result<()> {
             to_sample,
             to_sample as f64 / dt.as_secs_f64(),
         );
+
+        match prompt {
+            Prompt::One(_) => break,
+            Prompt::Interactive => {}
+            Prompt::Chat => {
+                pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
+            }
+        }
+    }
+
     Ok(())
 }
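
This final match is what distinguishes the three modes: a one-shot prompt breaks out of the loop after a single generation, interactive mode loops again with an empty history, and chat mode folds the prompt and the freshly generated tokens into pre_prompt_tokens so the next turn sees the whole conversation. A minimal sketch of that bookkeeping with a stubbed-out generator (the generate function and the token values are placeholders, not part of the diff):

#[derive(Debug)]
#[allow(dead_code)]
enum Prompt {
    Interactive,
    Chat,
    One(String),
}

// Placeholder for the real model call: pretend each turn emits two tokens.
fn generate(prompt_tokens: &[u32]) -> Vec<u32> {
    vec![prompt_tokens.len() as u32, 0]
}

fn main() {
    let prompt = Prompt::Chat;
    let mut pre_prompt_tokens: Vec<u32> = vec![];
    for turn in 0..3u32 {
        // In the real example these come from the tokenizer; here they are made up.
        let new_tokens: Vec<u32> = vec![100 + turn, 200 + turn];
        let prompt_tokens = [pre_prompt_tokens.as_slice(), new_tokens.as_slice()].concat();
        let all_tokens = generate(&prompt_tokens);
        match prompt {
            Prompt::One(_) => break,
            Prompt::Interactive => {} // the next turn starts from an empty history
            Prompt::Chat => {
                // Chat mode: prompt and generated tokens become the next turn's prefix.
                pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
            }
        }
        println!("turn {turn}: history is now {} tokens", pre_prompt_tokens.len());
    }
}
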

candle-examples/examples/quantized/model.rs

@@ -5,7 +5,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module};
 
-const MAX_SEQ_LEN: usize = 4096;
+pub const MAX_SEQ_LEN: usize = 4096;
 
 struct RmsNorm {
     inner: candle_nn::LayerNorm,

@@ -126,10 +126,14 @@ impl LayerWeights {
         let (k, v) = match &self.kv_cache {
             None => (k, v),
             Some((k_cache, v_cache)) => {
+                if index_pos == 0 {
+                    (k, v)
+                } else {
                     let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
                     let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
                     (k, v)
+                }
             }
         };
         self.kv_cache = Some((k.clone(), v.clone()));
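
The index_pos == 0 guard is what lets the interactive loop reuse the same ModelWeights across prompts: a forward pass that starts at position 0 ignores whatever the cache holds and overwrites it, while later positions keep concatenating the new keys and values onto the cached ones along the sequence dimension. A toy sketch of the same pattern with Vec<u32> standing in for the cached key/value tensors (the Layer type and its forward signature are illustrative, not the real LayerWeights API):

// Stand-in for a single layer's KV cache; the real code stores a (Tensor, Tensor) pair.
struct Layer {
    kv_cache: Option<Vec<u32>>,
}

impl Layer {
    fn forward(&mut self, new_kv: Vec<u32>, index_pos: usize) -> Vec<u32> {
        let kv = match &self.kv_cache {
            None => new_kv,
            Some(cache) => {
                if index_pos == 0 {
                    // A new prompt starts at position 0: discard the stale cache.
                    new_kv
                } else {
                    // Continuing the same sequence: append to the cached entries,
                    // as the diff does with Tensor::cat along the sequence dim.
                    [cache.as_slice(), new_kv.as_slice()].concat()
                }
            }
        };
        self.kv_cache = Some(kv.clone());
        kv
    }
}

fn main() {
    let mut layer = Layer { kv_cache: None };
    assert_eq!(layer.forward(vec![1, 2, 3], 0), vec![1, 2, 3]); // first prompt
    assert_eq!(layer.forward(vec![4], 3), vec![1, 2, 3, 4]);    // same sequence, appended
    assert_eq!(layer.forward(vec![9, 9], 0), vec![9, 9]);       // new prompt resets the cache
}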