Interactive mode for the quantized model. (#690)

Laurent Mazare, 2023-08-31 11:52:42 +02:00 (committed via GitHub)
parent 94aa234dfd
commit 7509c98970
2 changed files with 113 additions and 59 deletions


@@ -17,6 +17,13 @@ use model::ModelWeights;
 
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
+#[derive(Debug)]
+enum Prompt {
+    Interactive,
+    Chat,
+    One(String),
+}
+
 #[derive(Clone, Debug, Copy, ValueEnum)]
 enum Which {
     #[value(name = "7b")]
@ -46,7 +53,9 @@ struct Args {
#[arg(long)] #[arg(long)]
model: Option<String>, model: Option<String>,
/// The initial prompt. /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
/// is preserved.
#[arg(long)] #[arg(long)]
prompt: Option<String>, prompt: Option<String>,
@@ -247,62 +256,103 @@ fn main() -> anyhow::Result<()> {
     println!("model built");
 
     let tokenizer = args.tokenizer()?;
-    let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
-    let tokens = tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?;
-    if args.verbose_prompt {
-        for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
-            let token = token.replace('▁', " ").replace("<0x0A>", "\n");
-            println!("{id:7} -> '{token}'");
-        }
-    }
-
-    let prompt_tokens = tokens.get_ids().to_vec();
-    let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-
-    print!("{prompt}");
-    let start_prompt_processing = std::time::Instant::now();
-    let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
-        let logits = model.forward(&input, 0)?;
-        let logits = logits.squeeze(0)?;
-        logits_processor.sample(&logits)?
-    };
-    let prompt_dt = start_prompt_processing.elapsed();
-    all_tokens.push(next_token);
-    print_token(next_token, &tokenizer);
-
-    let to_sample = args.sample_len.saturating_sub(1);
-    let start_post_prompt = std::time::Instant::now();
-    for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
-        let logits = model.forward(&input, prompt_tokens.len() + index)?;
-        let logits = logits.squeeze(0)?;
-        let logits = if args.repeat_penalty == 1. {
-            logits
-        } else {
-            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
-            candle_transformers::utils::apply_repeat_penalty(
-                &logits,
-                args.repeat_penalty,
-                &all_tokens[start_at..],
-            )?
-        };
-        next_token = logits_processor.sample(&logits)?;
-        all_tokens.push(next_token);
-        print_token(next_token, &tokenizer);
-    }
-    let dt = start_post_prompt.elapsed();
-    println!(
-        "\n\n{:4} prompt tokens processed: {:.2} token/s",
-        prompt_tokens.len(),
-        prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
-    );
-    println!(
-        "{:4} tokens generated: {:.2} token/s",
-        to_sample,
-        to_sample as f64 / dt.as_secs_f64(),
-    );
+    let prompt = match args.prompt.as_deref() {
+        Some("chat") => Prompt::Chat,
+        Some("interactive") => Prompt::Interactive,
+        Some(s) => Prompt::One(s.to_string()),
+        None => Prompt::One(DEFAULT_PROMPT.to_string()),
+    };
+
+    let mut pre_prompt_tokens = vec![];
+    loop {
+        let prompt_str = match &prompt {
+            Prompt::One(prompt) => prompt.clone(),
+            Prompt::Interactive | Prompt::Chat => {
+                print!("> ");
+                std::io::stdout().flush()?;
+                let mut prompt = String::new();
+                std::io::stdin().read_line(&mut prompt)?;
+                if prompt.ends_with('\n') {
+                    prompt.pop();
+                    if prompt.ends_with('\r') {
+                        prompt.pop();
+                    }
+                }
+                prompt
+            }
+        };
+        print!("{}", &prompt_str);
+        let tokens = tokenizer
+            .encode(prompt_str, true)
+            .map_err(anyhow::Error::msg)?;
+        if args.verbose_prompt {
+            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                println!("{id:7} -> '{token}'");
+            }
+        }
+
+        let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
+        let to_sample = args.sample_len.saturating_sub(1);
+        let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
+            let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
+            prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
+        } else {
+            prompt_tokens
+        };
+        let mut all_tokens = vec![];
+        let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+
+        let start_prompt_processing = std::time::Instant::now();
+        let mut next_token = {
+            let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+            let logits = model.forward(&input, 0)?;
+            let logits = logits.squeeze(0)?;
+            logits_processor.sample(&logits)?
+        };
+        let prompt_dt = start_prompt_processing.elapsed();
+        all_tokens.push(next_token);
+        print_token(next_token, &tokenizer);
+
+        let start_post_prompt = std::time::Instant::now();
+        for index in 0..to_sample {
+            let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+            let logits = model.forward(&input, prompt_tokens.len() + index)?;
+            let logits = logits.squeeze(0)?;
+            let logits = if args.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    args.repeat_penalty,
+                    &all_tokens[start_at..],
+                )?
+            };
+            next_token = logits_processor.sample(&logits)?;
+            all_tokens.push(next_token);
+            print_token(next_token, &tokenizer);
+        }
+        let dt = start_post_prompt.elapsed();
+        println!(
+            "\n\n{:4} prompt tokens processed: {:.2} token/s",
+            prompt_tokens.len(),
+            prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
+        );
+        println!(
+            "{:4} tokens generated: {:.2} token/s",
+            to_sample,
+            to_sample as f64 / dt.as_secs_f64(),
+        );
+
+        match prompt {
+            Prompt::One(_) => break,
+            Prompt::Interactive => {}
+            Prompt::Chat => {
+                pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
+            }
+        }
+    }
 
     Ok(())
 }
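
Note on the truncation in the new loop: the accumulated history plus the requested sample count is kept roughly ten tokens under model::MAX_SEQ_LEN, with the oldest tokens dropped otherwise. A minimal, self-contained sketch of that rule; the constant mirrors model::MAX_SEQ_LEN from this commit, while the helper name and the numbers in main are illustrative only:

const MAX_SEQ_LEN: usize = 4096;

// Keep only the most recent prompt tokens so that prompt + to_sample fits in the
// context window with a small safety margin, matching the arithmetic in the diff.
fn truncate_prompt(prompt_tokens: Vec<u32>, to_sample: usize) -> Vec<u32> {
    if prompt_tokens.len() + to_sample > MAX_SEQ_LEN - 10 {
        let to_remove = prompt_tokens.len() + to_sample + 10 - MAX_SEQ_LEN;
        prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
    } else {
        prompt_tokens
    }
}

fn main() {
    // 4000 history tokens plus 1000 requested samples exceed 4096 - 10, so the oldest
    // 914 tokens are dropped and 3086 = 4096 - 10 - 1000 remain.
    let history: Vec<u32> = (0..4000).collect();
    let kept = truncate_prompt(history, 1000);
    assert_eq!(kept.len(), 3086);
    assert_eq!(kept[0], 914);
}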


@@ -5,7 +5,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module};
 
-const MAX_SEQ_LEN: usize = 4096;
+pub const MAX_SEQ_LEN: usize = 4096;
 
 struct RmsNorm {
     inner: candle_nn::LayerNorm,
@@ -126,9 +126,13 @@ impl LayerWeights {
         let (k, v) = match &self.kv_cache {
             None => (k, v),
             Some((k_cache, v_cache)) => {
-                let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
-                let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
-                (k, v)
+                if index_pos == 0 {
+                    (k, v)
+                } else {
+                    let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+                    let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+                    (k, v)
+                }
             }
         };
         self.kv_cache = Some((k.clone(), v.clone()));
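
The model.rs change pairs with the chat loop in main.rs: in chat mode the previous prompt and the generated tokens are fed back as pre_prompt_tokens and processing restarts at index_pos == 0, so any cache left over from the previous turn must be discarded rather than concatenated onto. A minimal stand-in for that update, using a plain Vec<f32> instead of candle tensors; the struct and method names are made up for illustration:

// Sketch of the cache decision above: at index_pos == 0 a fresh prompt is being
// processed, so the previous turn's cache is dropped instead of extended.
struct KvCache {
    cache: Option<Vec<f32>>,
}

impl KvCache {
    fn update(&mut self, index_pos: usize, new: Vec<f32>) -> Vec<f32> {
        let merged = match &self.cache {
            None => new,
            Some(prev) => {
                if index_pos == 0 {
                    // New prompt: discard the stale cache from the previous turn.
                    new
                } else {
                    // Continuing the same sequence: append to the cached values.
                    let mut merged = prev.clone();
                    merged.extend(new);
                    merged
                }
            }
        };
        self.cache = Some(merged.clone());
        merged
    }
}

fn main() {
    let mut kv = KvCache { cache: None };
    assert_eq!(kv.update(0, vec![1.0, 2.0]).len(), 2); // first prompt fills the cache
    assert_eq!(kv.update(2, vec![3.0]).len(), 3);      // generation step appends
    assert_eq!(kv.update(0, vec![9.0]).len(), 1);      // next prompt resets the cache
}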