mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00

Interactive mode for the quantized model. (#690)
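
This change adds an interactive mode to the quantized LLaMA example. A new `Prompt` enum distinguishes a one-shot prompt (`Prompt::One`) from two looping modes selected through the `--prompt` flag: `interactive` reads a fresh prompt from stdin on every turn, and `chat` additionally carries the tokens of previous prompts and completions over to the next turn. To support this, the model module's `MAX_SEQ_LEN` constant becomes `pub` so the example can clamp over-long histories, and the KV cache stops concatenating onto stale entries when generation restarts at position 0. The first three hunks below are from the example's main file, the last two from its model module.
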
@@ -17,6 +17,13 @@ use model::ModelWeights;
 
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
+#[derive(Debug)]
+enum Prompt {
+    Interactive,
+    Chat,
+    One(String),
+}
+
 #[derive(Clone, Debug, Copy, ValueEnum)]
 enum Which {
     #[value(name = "7b")]
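
`Prompt::One` carries a literal prompt string, while the two interactive variants differ only in whether token history is carried across turns; the distinction is applied in the `match prompt` at the end of the generation loop below.
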
@@ -46,7 +53,9 @@ struct Args {
     #[arg(long)]
     model: Option<String>,
 
-    /// The initial prompt.
+    /// The initial prompt. Use 'interactive' for entering multiple prompts in an interactive way,
+    /// and 'chat' for an interactive mode where the history of previous prompts and generated
+    /// tokens is preserved.
     #[arg(long)]
     prompt: Option<String>,
 
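
Note that the strings `interactive` and `chat` are now reserved: a literal prompt consisting of exactly one of those words can no longer be passed through `--prompt`. Assuming the example keeps its `quantized` name, the new modes would be entered with something like `cargo run --example quantized --release -- --prompt chat`.
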
@@ -247,62 +256,103 @@ fn main() -> anyhow::Result<()> {
     println!("model built");
 
     let tokenizer = args.tokenizer()?;
-    let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
-    let tokens = tokenizer.encode(prompt, true).map_err(anyhow::Error::msg)?;
-    if args.verbose_prompt {
-        for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
-            let token = token.replace('▁', " ").replace("<0x0A>", "\n");
-            println!("{id:7} -> '{token}'");
-        }
-    }
-
-    let prompt_tokens = tokens.get_ids().to_vec();
-    let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-
-    print!("{prompt}");
-
-    let start_prompt_processing = std::time::Instant::now();
-    let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
-        let logits = model.forward(&input, 0)?;
-        let logits = logits.squeeze(0)?;
-        logits_processor.sample(&logits)?
-    };
-    let prompt_dt = start_prompt_processing.elapsed();
-    all_tokens.push(next_token);
-    print_token(next_token, &tokenizer);
-
-    let to_sample = args.sample_len.saturating_sub(1);
-    let start_post_prompt = std::time::Instant::now();
-    for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
-        let logits = model.forward(&input, prompt_tokens.len() + index)?;
-        let logits = logits.squeeze(0)?;
-        let logits = if args.repeat_penalty == 1. {
-            logits
-        } else {
-            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
-            candle_transformers::utils::apply_repeat_penalty(
-                &logits,
-                args.repeat_penalty,
-                &all_tokens[start_at..],
-            )?
-        };
-        next_token = logits_processor.sample(&logits)?;
-        all_tokens.push(next_token);
-        print_token(next_token, &tokenizer);
-    }
-    let dt = start_post_prompt.elapsed();
-    println!(
-        "\n\n{:4} prompt tokens processed: {:.2} token/s",
-        prompt_tokens.len(),
-        prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
-    );
-    println!(
-        "{:4} tokens generated: {:.2} token/s",
-        to_sample,
-        to_sample as f64 / dt.as_secs_f64(),
-    );
+    let prompt = match args.prompt.as_deref() {
+        Some("chat") => Prompt::Chat,
+        Some("interactive") => Prompt::Interactive,
+        Some(s) => Prompt::One(s.to_string()),
+        None => Prompt::One(DEFAULT_PROMPT.to_string()),
+    };
+
+    let mut pre_prompt_tokens = vec![];
+    loop {
+        let prompt_str = match &prompt {
+            Prompt::One(prompt) => prompt.clone(),
+            Prompt::Interactive | Prompt::Chat => {
+                print!("> ");
+                std::io::stdout().flush()?;
+                let mut prompt = String::new();
+                std::io::stdin().read_line(&mut prompt)?;
+                if prompt.ends_with('\n') {
+                    prompt.pop();
+                    if prompt.ends_with('\r') {
+                        prompt.pop();
+                    }
+                }
+                prompt
+            }
+        };
+        print!("{}", &prompt_str);
+        let tokens = tokenizer
+            .encode(prompt_str, true)
+            .map_err(anyhow::Error::msg)?;
+        if args.verbose_prompt {
+            for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
+                let token = token.replace('▁', " ").replace("<0x0A>", "\n");
+                println!("{id:7} -> '{token}'");
+            }
+        }
+
+        let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
+        let to_sample = args.sample_len.saturating_sub(1);
+        let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 {
+            let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN;
+            prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
+        } else {
+            prompt_tokens
+        };
+        let mut all_tokens = vec![];
+        let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+
+        let start_prompt_processing = std::time::Instant::now();
+        let mut next_token = {
+            let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+            let logits = model.forward(&input, 0)?;
+            let logits = logits.squeeze(0)?;
+            logits_processor.sample(&logits)?
+        };
+        let prompt_dt = start_prompt_processing.elapsed();
+        all_tokens.push(next_token);
+        print_token(next_token, &tokenizer);
+
+        let start_post_prompt = std::time::Instant::now();
+        for index in 0..to_sample {
+            let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+            let logits = model.forward(&input, prompt_tokens.len() + index)?;
+            let logits = logits.squeeze(0)?;
+            let logits = if args.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    args.repeat_penalty,
+                    &all_tokens[start_at..],
+                )?
+            };
+            next_token = logits_processor.sample(&logits)?;
+            all_tokens.push(next_token);
+            print_token(next_token, &tokenizer);
+        }
+        let dt = start_post_prompt.elapsed();
+        println!(
+            "\n\n{:4} prompt tokens processed: {:.2} token/s",
+            prompt_tokens.len(),
+            prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
+        );
+        println!(
+            "{:4} tokens generated: {:.2} token/s",
+            to_sample,
+            to_sample as f64 / dt.as_secs_f64(),
+        );
+
+        match prompt {
+            Prompt::One(_) => break,
+            Prompt::Interactive => {}
+            Prompt::Chat => {
+                pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
+            }
+        }
+    }
+
     Ok(())
 }
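
Chat mode works by threading `pre_prompt_tokens` through the loop: after each turn, the prompt tokens and the freshly sampled tokens are concatenated and carried into the next turn's encoding, while the `MAX_SEQ_LEN` guard above trims the combined window. Below is a minimal sketch of that accumulation, outside of candle and with made-up token IDs, just to show the data flow:

fn main() {
    // Stand-in for encoding + generation: each turn has some freshly
    // encoded prompt ids, and the "model" produces some sampled ids.
    let turns: [(Vec<u32>, Vec<u32>); 2] = [
        (vec![1, 2, 3], vec![100]), // turn 1: (prompt ids, generated ids)
        (vec![7, 8], vec![101]),    // turn 2
    ];

    let mut pre_prompt_tokens: Vec<u32> = vec![];
    for (encoded, generated) in turns {
        // History is prepended to the freshly encoded prompt, as in
        // `[&pre_prompt_tokens, tokens.get_ids()].concat()`.
        let prompt_tokens = [pre_prompt_tokens.as_slice(), encoded.as_slice()].concat();
        // ... forward pass and sampling would happen here ...
        // Chat mode then carries the whole exchange into the next turn.
        pre_prompt_tokens = [prompt_tokens.as_slice(), generated.as_slice()].concat();
    }
    assert_eq!(pre_prompt_tokens, vec![1, 2, 3, 100, 7, 8, 101]);
}

Since `Prompt::Interactive` skips the `Prompt::Chat` arm, its history stays empty and every turn starts from scratch.
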
@@ -5,7 +5,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module};
 
-const MAX_SEQ_LEN: usize = 4096;
+pub const MAX_SEQ_LEN: usize = 4096;
 
 struct RmsNorm {
     inner: candle_nn::LayerNorm,
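
Making the constant `pub` is what lets the truncation guard in the previous hunk refer to it as `model::MAX_SEQ_LEN` from the example's main file.
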
@@ -126,9 +126,13 @@ impl LayerWeights {
         let (k, v) = match &self.kv_cache {
             None => (k, v),
             Some((k_cache, v_cache)) => {
-                let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
-                let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
-                (k, v)
+                if index_pos == 0 {
+                    (k, v)
+                } else {
+                    let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+                    let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+                    (k, v)
+                }
             }
         };
         self.kv_cache = Some((k.clone(), v.clone()));
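
The `index_pos == 0` check matters precisely for the new interactive mode: a second prompt restarts generation at position 0 while each layer's cache still holds the keys and values of the previous exchange, and the old code would have concatenated the new entries onto that stale state. A rough standalone sketch of the same reset-on-restart idea (not candle API), with plain `Vec`s standing in for the K/V tensors and `extend` for `Tensor::cat` along the sequence dimension:

struct KvCache {
    cache: Option<(Vec<f32>, Vec<f32>)>,
}

impl KvCache {
    fn append(&mut self, k: Vec<f32>, v: Vec<f32>, index_pos: usize) -> (Vec<f32>, Vec<f32>) {
        let (k, v) = match self.cache.take() {
            // First call ever, or a new prompt restarting at position 0:
            // use the fresh entries and drop whatever was cached.
            None => (k, v),
            Some(_) if index_pos == 0 => (k, v),
            // Mid-generation: append to the cached keys/values.
            Some((mut k_cache, mut v_cache)) => {
                k_cache.extend(k);
                v_cache.extend(v);
                (k_cache, v_cache)
            }
        };
        self.cache = Some((k.clone(), v.clone()));
        (k, v)
    }
}

fn main() {
    let mut c = KvCache { cache: None };
    c.append(vec![1.0], vec![1.0], 0); // first prompt
    c.append(vec![2.0], vec![2.0], 1); // generation step extends the cache
    let (k, _v) = c.append(vec![9.0], vec![9.0], 0); // second prompt resets it
    assert_eq!(k, vec![9.0]);
}
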