Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Line up the llama.cpp implementation with the candle one. (#518)
* Separate the prompt stats from the post-prompt ones in the quantized example.
* Slightly nicer output printing.
* Line up with the llama.cpp implementation.
@@ -81,15 +81,30 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
-        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
-        let cos = self.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
-        let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
-        let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
-        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
+        let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
+        let cos = self
+            .cos
+            .narrow(0, index_pos, seq_len)?
+            .reshape((seq_len, n_embd / 2, 1))?;
+        let sin = self
+            .sin
+            .narrow(0, index_pos, seq_len)?
+            .reshape((seq_len, n_embd / 2, 1))?;
+        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        // This mimics the llama.cpp behavior.
+        // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
+        // The x0 and x1 values are interleaved on the n_embd (= head_dim) dimension.
+        // The resulting y0 and y1 are also interleaved with:
+        // y0 = x0*cos - x1*sin
+        // y1 = x0*sin + x1*cos
+        let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
+        let x0 = x.narrow(D::Minus1, 0, 1)?;
+        let x1 = x.narrow(D::Minus1, 1, 1)?;
+        let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
+        let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+        let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
+        let rope = rope.flatten_from(D::Minus2)?;
         Ok(rope)
     }
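To make the comment block above concrete, here is a minimal standalone sketch of the interleaved rotation on a plain Rust slice. It is an illustration only, not the candle code: the head_dim, position, and theta base values are made up, and the real implementation narrows precomputed cos/sin tables instead of recomputing the angles.

    fn rope_interleaved(head: &mut [f32], pos: usize, theta_base: f32) {
        // Consecutive pairs (x0, x1) along head_dim share one rotation angle,
        // applying exactly the y0/y1 formulas quoted in the diff comments.
        let dim = head.len();
        for i in 0..dim / 2 {
            let freq = theta_base.powf(-2.0 * i as f32 / dim as f32);
            let (sin, cos) = (pos as f32 * freq).sin_cos();
            let (x0, x1) = (head[2 * i], head[2 * i + 1]);
            head[2 * i] = x0 * cos - x1 * sin; // y0 = x0*cos - x1*sin
            head[2 * i + 1] = x0 * sin + x1 * cos; // y1 = x0*sin + x1*cos
        }
    }

    fn main() {
        let mut head = [1.0f32, 0.0, 1.0, 0.0]; // head_dim = 4, i.e. two (x0, x1) pairs
        rope_interleaved(&mut head, 3, 10_000.0);
        println!("{head:?}");
    }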
@@ -172,9 +187,6 @@ impl ModelWeights {
             .to_dtype(DType::F32)?
             .reshape((MAX_SEQ_LEN, 1))?
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-        // This is different from the paper, see:
-        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
-        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
         let cos = idx_theta.cos()?;
         let sin = idx_theta.sin()?;
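The removed self-concatenation ties into the change above: the previous rotate-half formulation needed cos/sin tables covering the full head_dim, while the interleaved formulation only needs head_dim / 2 angles per position. A rough sketch of the angle table the surrounding code computes, in plain Rust with hypothetical names and plain vectors instead of candle tensors:

    // angle[t][i] = t * base^(-2*i / head_dim), for i in 0..head_dim / 2.
    // cos() / sin() of this table is what apply_rotary_emb narrows per position.
    fn rope_angles(seq_len: usize, head_dim: usize, base: f32) -> Vec<Vec<f32>> {
        (0..seq_len)
            .map(|t| {
                (0..head_dim / 2)
                    .map(|i| t as f32 * base.powf(-2.0 * i as f32 / head_dim as f32))
                    .collect()
            })
            .collect()
    }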
@@ -335,6 +347,31 @@ impl Args {
     }
 }
 
+fn print_token(next_token: u32, tokenizer: &Tokenizer) {
+    // Extracting the last token as a string is complicated, here we just apply some simple
+    // heuristics as it seems to work well enough for this example. See the following for more
+    // details:
+    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+    if let Some(text) = tokenizer.id_to_token(next_token) {
+        let text = text.replace('▁', " ");
+        let ascii = text
+            .strip_prefix("<0x")
+            .and_then(|t| t.strip_suffix('>'))
+            .and_then(|t| u8::from_str_radix(t, 16).ok());
+        match ascii {
+            None => print!("{text}"),
+            Some(ascii) => {
+                if let Some(chr) = char::from_u32(ascii as u32) {
+                    if chr.is_ascii() {
+                        print!("{chr}")
+                    }
+                }
+            }
+        }
+        let _ = std::io::stdout().flush();
+    }
+}
+
 fn main() -> anyhow::Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
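For context on the heuristic in print_token: SentencePiece byte-fallback tokens are spelled like <0x41>, and the strip_prefix / from_str_radix chain recovers the raw byte so it can be printed as a character. A small self-contained illustration (the helper name is hypothetical, not part of the diff):

    fn byte_fallback(token: &str) -> Option<u8> {
        // "<0x41>" -> Some(0x41); anything else -> None, mirroring the chain used above.
        token
            .strip_prefix("<0x")
            .and_then(|t| t.strip_suffix('>'))
            .and_then(|t| u8::from_str_radix(t, 16).ok())
    }

    fn main() {
        assert_eq!(byte_fallback("<0x41>"), Some(0x41)); // would be printed as 'A'
        assert_eq!(byte_fallback("▁hello"), None); // printed as-is, with '▁' mapped to a space
    }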
@@ -395,39 +432,40 @@ fn main() -> anyhow::Result<()> {
         }
     }
 
-    let mut tokens = tokens.get_ids().to_vec();
-    let mut index_pos = 0;
+    let prompt_tokens = tokens.get_ids().to_vec();
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-    let start_gen = std::time::Instant::now();
-    let mut token_generated = 0;
     print!("{prompt}");
-    for index in 0..args.sample_len {
-        let context_size = if index == 0 { tokens.len() } else { 1 };
-        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
-        let logits = model.forward(&input, index_pos)?;
-        let logits = logits.squeeze(0)?;
-        index_pos += ctxt.len();
-        let next_token = logits_processor.sample(&logits)?;
-        token_generated += 1;
-        tokens.push(next_token);
-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
-            std::io::stdout().flush()?;
-        }
+    let start_prompt_processing = std::time::Instant::now();
+    let mut next_token = {
+        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+        let logits = model.forward(&input, 0)?;
+        let logits = logits.squeeze(0)?;
+        logits_processor.sample(&logits)?
+    };
+    let prompt_dt = start_prompt_processing.elapsed();
+    print_token(next_token, &tokenizer);
+
+    let to_sample = args.sample_len.saturating_sub(1);
+    let start_post_prompt = std::time::Instant::now();
+    for index in 0..to_sample {
+        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+        let logits = model.forward(&input, prompt_tokens.len() + index)?;
+        let logits = logits.squeeze(0)?;
+        next_token = logits_processor.sample(&logits)?;
+        print_token(next_token, &tokenizer);
     }
-    let dt = start_gen.elapsed();
+    let dt = start_post_prompt.elapsed();
     println!(
-        "\n\n{} tokens generated ({:.2} token/s)\n",
-        token_generated,
-        token_generated as f64 / dt.as_secs_f64(),
+        "\n\n{:4} prompt tokens processed: {:.2} token/s",
+        prompt_tokens.len(),
+        prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
+    );
+    println!(
+        "{:4} tokens generated: {:.2} token/s",
+        to_sample,
+        to_sample as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
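The restructured main loop times the single prompt forward pass separately from the token-by-token generation loop, so the two println! calls report two throughputs instead of one aggregate number. A small sketch of that arithmetic with made-up timings (the helper name and numbers are illustrative only):

    use std::time::Duration;

    fn tokens_per_second(n_tokens: usize, elapsed: Duration) -> f64 {
        n_tokens as f64 / elapsed.as_secs_f64()
    }

    fn main() {
        // Hypothetical run: 13 prompt tokens in 0.5 s, 99 generated tokens in 9.9 s.
        let (prompt_dt, gen_dt) = (Duration::from_millis(500), Duration::from_millis(9_900));
        println!("{:4} prompt tokens processed: {:.2} token/s", 13, tokens_per_second(13, prompt_dt));
        println!("{:4} tokens generated: {:.2} token/s", 99, tokens_per_second(99, gen_dt));
    }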