mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Line up the llama.cpp implementation with the candle one. (#518)
* Separate the prompt stats from the post-prompt ones in the quantized example.
* Slightly nicer output printing.
* Line up with the llama.cpp implementation.
@@ -81,15 +81,30 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
-        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
-        let cos = self.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
-        let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
-        let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
-        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
+        let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
+        let cos = self
+            .cos
+            .narrow(0, index_pos, seq_len)?
+            .reshape((seq_len, n_embd / 2, 1))?;
+        let sin = self
+            .sin
+            .narrow(0, index_pos, seq_len)?
+            .reshape((seq_len, n_embd / 2, 1))?;
+        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        // This mimics the llama.cpp behavior.
+        // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
+        // The x0 and x1 values are interleaved on the n_embd (= head_dim) dimension.
+        // The resulting y0 and y1 are also interleaved, with:
+        //   y0 = x0*cos - x1*sin
+        //   y1 = x0*sin + x1*cos
+        let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
+        let x0 = x.narrow(D::Minus1, 0, 1)?;
+        let x1 = x.narrow(D::Minus1, 1, 1)?;
+        let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
+        let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+        let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
+        let rope = rope.flatten_from(D::Minus2)?;
         Ok(rope)
     }

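For intuition, the interleaved rotation introduced above can be spelled out on plain slices. The following is an illustrative sketch, not code from this commit (rope_interleaved is a hypothetical helper), applying the same y0/y1 formulas to each (x[2i], x[2i+1]) pair:

// Illustrative sketch only: the interleaved rotation above on a flat f32
// slice. The pair (x[2i], x[2i + 1]) is rotated by the i-th angle:
//   y0 = x0*cos - x1*sin
//   y1 = x0*sin + x1*cos
fn rope_interleaved(x: &[f32], cos: &[f32], sin: &[f32]) -> Vec<f32> {
    let mut y = vec![0f32; x.len()];
    for i in 0..x.len() / 2 {
        let (x0, x1) = (x[2 * i], x[2 * i + 1]);
        y[2 * i] = x0 * cos[i] - x1 * sin[i];
        y[2 * i + 1] = x0 * sin[i] + x1 * cos[i];
    }
    y
}

fn main() {
    // cos = [1, 0], sin = [0, 1]: the first pair is left as-is (angle 0),
    // the second pair is rotated by a quarter turn.
    let y = rope_interleaved(&[1., 2., 3., 4.], &[1., 0.], &[0., 1.]);
    assert_eq!(y, vec![1., 2., -4., 3.]);
}

The removed lines instead rotated the two contiguous halves of the head dimension against each other (the layout used by the transformers LLaMA port), which is exactly the mismatch with ggml's interleaved rope that this commit resolves.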
@@ -172,9 +187,6 @@ impl ModelWeights {
             .to_dtype(DType::F32)?
             .reshape((MAX_SEQ_LEN, 1))?
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-        // This is different from the paper, see:
-        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
-        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
         let cos = idx_theta.cos()?;
         let sin = idx_theta.sin()?;

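The dropped Tensor::cat duplication follows from the new apply_rotary_emb: cos and sin are now indexed once per rotation pair, so the precomputed tables only need head_dim / 2 columns. A rough scalar sketch of such a table (rope_tables is a hypothetical helper, assuming the usual theta_i = 1 / 10000^(2i / head_dim) frequencies):

// Returns (cos, sin) tables of shape (max_seq_len, head_dim / 2):
// one angle per rotation pair, with no duplication along the last dim.
fn rope_tables(max_seq_len: usize, head_dim: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let half = head_dim / 2;
    let theta: Vec<f32> = (0..half)
        .map(|i| 1f32 / 10000f32.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    let mut cos = Vec::with_capacity(max_seq_len);
    let mut sin = Vec::with_capacity(max_seq_len);
    for pos in 0..max_seq_len {
        cos.push(theta.iter().map(|t| (pos as f32 * t).cos()).collect::<Vec<f32>>());
        sin.push(theta.iter().map(|t| (pos as f32 * t).sin()).collect::<Vec<f32>>());
    }
    (cos, sin)
}

fn main() {
    let (cos, sin) = rope_tables(4096, 64);
    assert_eq!((cos.len(), cos[0].len()), (4096, 32));
    // Position 0 has angle 0 everywhere: cos = 1, sin = 0.
    assert!((cos[0][0] - 1.0).abs() < 1e-6 && sin[0][0].abs() < 1e-6);
}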
@@ -335,6 +347,31 @@ impl Args {
     }
 }

+fn print_token(next_token: u32, tokenizer: &Tokenizer) {
+    // Extracting the last token as a string is complicated, here we just apply some simple
+    // heuristics as it seems to work well enough for this example. See the following for more
+    // details:
+    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+    if let Some(text) = tokenizer.id_to_token(next_token) {
+        let text = text.replace('▁', " ");
+        let ascii = text
+            .strip_prefix("<0x")
+            .and_then(|t| t.strip_suffix('>'))
+            .and_then(|t| u8::from_str_radix(t, 16).ok());
+        match ascii {
+            None => print!("{text}"),
+            Some(ascii) => {
+                if let Some(chr) = char::from_u32(ascii as u32) {
+                    if chr.is_ascii() {
+                        print!("{chr}")
+                    }
+                }
+            }
+        }
+        let _ = std::io::stdout().flush();
+    }
+}
+
 fn main() -> anyhow::Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
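The <0x..> branch in print_token handles SentencePiece byte-fallback tokens, whose string form is a hex byte like "<0x41>" rather than printable text. A standalone sketch of just that heuristic (parse_byte_token is a hypothetical name):

// Byte-fallback tokens are spelled "<0xNN>" and decode to a single byte;
// everything else is an ordinary token printed verbatim.
fn parse_byte_token(text: &str) -> Option<u8> {
    text.strip_prefix("<0x")
        .and_then(|t| t.strip_suffix('>'))
        .and_then(|t| u8::from_str_radix(t, 16).ok())
}

fn main() {
    assert_eq!(parse_byte_token("<0x41>"), Some(0x41)); // prints as 'A'
    assert_eq!(parse_byte_token("hello"), None); // ordinary token
}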
@@ -395,39 +432,40 @@ fn main() -> anyhow::Result<()> {
         }
     }

-    let mut tokens = tokens.get_ids().to_vec();
-    let mut index_pos = 0;
+    let prompt_tokens = tokens.get_ids().to_vec();
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-    let start_gen = std::time::Instant::now();
-    let mut token_generated = 0;

     print!("{prompt}");
-    for index in 0..args.sample_len {
-        let context_size = if index == 0 { tokens.len() } else { 1 };
-        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
-        let logits = model.forward(&input, index_pos)?;
+
+    let start_prompt_processing = std::time::Instant::now();
+    let mut next_token = {
+        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+        let logits = model.forward(&input, 0)?;
         let logits = logits.squeeze(0)?;
-        index_pos += ctxt.len();
+        logits_processor.sample(&logits)?
+    };
+    let prompt_dt = start_prompt_processing.elapsed();
+    print_token(next_token, &tokenizer);
+
-        let next_token = logits_processor.sample(&logits)?;
-        token_generated += 1;
-        tokens.push(next_token);
-
-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
-            std::io::stdout().flush()?;
-        }
+    let to_sample = args.sample_len.saturating_sub(1);
+    let start_post_prompt = std::time::Instant::now();
+    for index in 0..to_sample {
+        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+        let logits = model.forward(&input, prompt_tokens.len() + index)?;
+        let logits = logits.squeeze(0)?;
+        next_token = logits_processor.sample(&logits)?;
+        print_token(next_token, &tokenizer);
     }
-    let dt = start_gen.elapsed();
+    let dt = start_post_prompt.elapsed();
     println!(
-        "\n\n{} tokens generated ({:.2} token/s)\n",
-        token_generated,
-        token_generated as f64 / dt.as_secs_f64(),
+        "\n\n{:4} prompt tokens processed: {:.2} token/s",
+        prompt_tokens.len(),
+        prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
+    );
+    println!(
+        "{:4} tokens generated: {:.2} token/s",
+        to_sample,
+        to_sample as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
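The restructured loop also changes how positions are fed to model.forward: the whole prompt goes through a single batched call at offset 0, and each sampled token is then fed back alone at its absolute offset, prompt_tokens.len() + index. A small sketch of that schedule (forward_schedule is a hypothetical helper, not code from the repository):

// Each entry is (tokens in the batch, index_pos passed to the model):
// one batched prompt call at position 0, then single-token decode calls
// at their absolute offsets in the sequence.
fn forward_schedule(prompt_len: usize, to_sample: usize) -> Vec<(usize, usize)> {
    let mut calls = vec![(prompt_len, 0)];
    for index in 0..to_sample {
        calls.push((1, prompt_len + index));
    }
    calls
}

fn main() {
    // A 4-token prompt with 3 sampled tokens: [(4, 0), (1, 4), (1, 5), (1, 6)]
    println!("{:?}", forward_schedule(4, 3));
}

Timing the two phases separately around these calls is what lets the example report prompt-processing and generation throughput as distinct numbers, matching how llama.cpp presents its stats.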