Line up the llama.cpp implementation with the candle one. (#518)

* Separate the prompt stats from the post-prompt ones in the quantized example.

* Slightly nicer output printing.

* Line up with the llama.cpp implementation.
This commit is contained in:
Laurent Mazare
2023-08-19 20:12:07 +01:00
committed by GitHub
parent 551409092e
commit d73ca3d28e

View File

@ -81,15 +81,30 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
impl LayerWeights { impl LayerWeights {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter(); let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, n_embd) = x.dims4()?; let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?; let cos = self
let sin = self.sin.narrow(0, index_pos, seq_len)?; .cos
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?; .narrow(0, index_pos, seq_len)?
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?; .reshape((seq_len, n_embd / 2, 1))?;
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?; let sin = self
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?; .sin
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; .narrow(0, index_pos, seq_len)?
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?; .reshape((seq_len, n_embd / 2, 1))?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
// This mimics the llama.cpp behavior.
// https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
// The resulting y0 and y1 are also interleaved with:
// y0 = x0*cos - x1*sin
// y1 = x0*sin + x1*cos
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
let rope = rope.flatten_from(D::Minus2)?;
Ok(rope) Ok(rope)
} }
@ -172,9 +187,6 @@ impl ModelWeights {
.to_dtype(DType::F32)? .to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))? .reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?; .matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?; let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?; let sin = idx_theta.sin()?;
@ -335,6 +347,31 @@ impl Args {
} }
} }
fn print_token(next_token: u32, tokenizer: &Tokenizer) {
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ");
let ascii = text
.strip_prefix("<0x")
.and_then(|t| t.strip_suffix('>'))
.and_then(|t| u8::from_str_radix(t, 16).ok());
match ascii {
None => print!("{text}"),
Some(ascii) => {
if let Some(chr) = char::from_u32(ascii as u32) {
if chr.is_ascii() {
print!("{chr}")
}
}
}
}
let _ = std::io::stdout().flush();
}
}
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder; use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
@ -395,39 +432,40 @@ fn main() -> anyhow::Result<()> {
} }
} }
let mut tokens = tokens.get_ids().to_vec(); let prompt_tokens = tokens.get_ids().to_vec();
let mut index_pos = 0;
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let start_gen = std::time::Instant::now();
let mut token_generated = 0;
print!("{prompt}"); print!("{prompt}");
for index in 0..args.sample_len {
let context_size = if index == 0 { tokens.len() } else { 1 }; let start_prompt_processing = std::time::Instant::now();
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let mut next_token = {
let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?; let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?; let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
index_pos += ctxt.len(); logits_processor.sample(&logits)?
};
let prompt_dt = start_prompt_processing.elapsed();
print_token(next_token, &tokenizer);
let next_token = logits_processor.sample(&logits)?; let to_sample = args.sample_len.saturating_sub(1);
token_generated += 1; let start_post_prompt = std::time::Instant::now();
tokens.push(next_token); for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
// Extracting the last token as a string is complicated, here we just apply some simple let logits = model.forward(&input, prompt_tokens.len() + index)?;
// heuristics as it seems to work well enough for this example. See the following for more let logits = logits.squeeze(0)?;
// details: next_token = logits_processor.sample(&logits)?;
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141 print_token(next_token, &tokenizer);
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
} }
let dt = start_gen.elapsed(); let dt = start_post_prompt.elapsed();
println!( println!(
"\n\n{} tokens generated ({:.2} token/s)\n", "\n\n{:4} prompt tokens processed: {:.2} token/s",
token_generated, prompt_tokens.len(),
token_generated as f64 / dt.as_secs_f64(), prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{:4} tokens generated: {:.2} token/s",
to_sample,
to_sample as f64 / dt.as_secs_f64(),
); );
Ok(()) Ok(())
} }