From d73ca3d28e977dd86c890187adad71aa71645756 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 19 Aug 2023 20:12:07 +0100
Subject: [PATCH] Line up the llama.cpp implementation with the candle one.
 (#518)

* Separate the prompt stats from the post-prompt ones in the quantized example.

* Slightly nicer output printing.

* Line up with the llama.cpp implementation.
---
 candle-examples/examples/quantized/main.rs | 116 ++++++++++++++-------
 1 file changed, 77 insertions(+), 39 deletions(-)

diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 7da7cf1c..45e75de1 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -81,15 +81,30 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
-        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
-        let cos = self.cos.narrow(0, index_pos, seq_len)?;
-        let sin = self.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
-        let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
-        let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
-        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
+        let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
+        let cos = self
+            .cos
+            .narrow(0, index_pos, seq_len)?
+            .reshape((seq_len, n_embd / 2, 1))?;
+        let sin = self
+            .sin
+            .narrow(0, index_pos, seq_len)?
+            .reshape((seq_len, n_embd / 2, 1))?;
+        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        // This mimics the llama.cpp behavior.
+        // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
+        // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
+        // The resulting y0 and y1 are also interleaved with:
+        // y0 = x0*cos - x1*sin
+        // y1 = x0*sin + x1*cos
+        let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
+        let x0 = x.narrow(D::Minus1, 0, 1)?;
+        let x1 = x.narrow(D::Minus1, 1, 1)?;
+        let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
+        let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+        let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
+        let rope = rope.flatten_from(D::Minus2)?;
         Ok(rope)
     }
 
@@ -172,9 +187,6 @@ impl ModelWeights {
             .to_dtype(DType::F32)?
             .reshape((MAX_SEQ_LEN, 1))?
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-        // This is different from the paper, see:
-        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
-        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
         let cos = idx_theta.cos()?;
         let sin = idx_theta.sin()?;
 
@@ -335,6 +347,31 @@ impl Args {
     }
 }
 
+fn print_token(next_token: u32, tokenizer: &Tokenizer) {
+    // Extracting the last token as a string is complicated, here we just apply some simple
+    // heuristics as it seems to work well enough for this example. See the following for more
+    // details:
+    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+    if let Some(text) = tokenizer.id_to_token(next_token) {
+        let text = text.replace('▁', " ");
+        let ascii = text
+            .strip_prefix("<0x")
+            .and_then(|t| t.strip_suffix('>'))
+            .and_then(|t| u8::from_str_radix(t, 16).ok());
+        match ascii {
+            None => print!("{text}"),
+            Some(ascii) => {
+                if let Some(chr) = char::from_u32(ascii as u32) {
+                    if chr.is_ascii() {
+                        print!("{chr}")
+                    }
+                }
+            }
+        }
+        let _ = std::io::stdout().flush();
+    }
+}
+
 fn main() -> anyhow::Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -395,39 +432,40 @@ fn main() -> anyhow::Result<()> {
         }
     }
 
-    let mut tokens = tokens.get_ids().to_vec();
-    let mut index_pos = 0;
+    let prompt_tokens = tokens.get_ids().to_vec();
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-    let start_gen = std::time::Instant::now();
-    let mut token_generated = 0;
+
     print!("{prompt}");
-    for index in 0..args.sample_len {
-        let context_size = if index == 0 { tokens.len() } else { 1 };
-        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
-        let logits = model.forward(&input, index_pos)?;
+
+    let start_prompt_processing = std::time::Instant::now();
+    let mut next_token = {
+        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+        let logits = model.forward(&input, 0)?;
         let logits = logits.squeeze(0)?;
-        index_pos += ctxt.len();
+        logits_processor.sample(&logits)?
+    };
+    let prompt_dt = start_prompt_processing.elapsed();
+    print_token(next_token, &tokenizer);
 
-        let next_token = logits_processor.sample(&logits)?;
-        token_generated += 1;
-        tokens.push(next_token);
-
-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
-            std::io::stdout().flush()?;
-        }
+    let to_sample = args.sample_len.saturating_sub(1);
+    let start_post_prompt = std::time::Instant::now();
+    for index in 0..to_sample {
+        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+        let logits = model.forward(&input, prompt_tokens.len() + index)?;
+        let logits = logits.squeeze(0)?;
+        next_token = logits_processor.sample(&logits)?;
+        print_token(next_token, &tokenizer);
    }
-    let dt = start_gen.elapsed();
+    let dt = start_post_prompt.elapsed();
     println!(
-        "\n\n{} tokens generated ({:.2} token/s)\n",
-        token_generated,
-        token_generated as f64 / dt.as_secs_f64(),
+        "\n\n{:4} prompt tokens processed: {:.2} token/s",
+        prompt_tokens.len(),
+        prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
+    );
+    println!(
+        "{:4} tokens generated: {:.2} token/s",
+        to_sample,
+        to_sample as f64 / dt.as_secs_f64(),
     );
     Ok(())
 }
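
A note on the rotary-embedding change in the first hunk: llama.cpp rotates
interleaved pairs (x[2i], x[2i+1]) within each head, while the previous candle
code rotated across the two halves of the head dimension, pairing x[i] with
x[i + n_embd/2]. The GGML checkpoints consumed by this example were exported
for the interleaved convention, so the half-split rotation pairs up the wrong
coefficients, which is what this patch lines up. The scalar sketch below
contrasts the two conventions for a single position; it is an illustration
only, the function names are made up, and none of this code is part of the
patch.

    // llama.cpp style: rotate interleaved pairs (head[2i], head[2i+1]).
    fn rope_interleaved(head: &[f32], cos: &[f32], sin: &[f32]) -> Vec<f32> {
        let mut out = vec![0f32; head.len()];
        for i in 0..head.len() / 2 {
            let (x0, x1) = (head[2 * i], head[2 * i + 1]);
            // y0 = x0*cos - x1*sin, y1 = x0*sin + x1*cos, as in the hunk above.
            out[2 * i] = x0 * cos[i] - x1 * sin[i];
            out[2 * i + 1] = x0 * sin[i] + x1 * cos[i];
        }
        out
    }

    // Previous candle behavior: rotate across the two halves of the head,
    // pairing head[i] with head[i + head.len() / 2].
    fn rope_half_split(head: &[f32], cos: &[f32], sin: &[f32]) -> Vec<f32> {
        let half = head.len() / 2;
        let mut out = vec![0f32; head.len()];
        for i in 0..half {
            let (x0, x1) = (head[i], head[i + half]);
            out[i] = x0 * cos[i] - x1 * sin[i];
            out[i + half] = x0 * sin[i] + x1 * cos[i];
        }
        out
    }

    fn main() {
        // head_dim = 4: the two conventions pair different elements, so they
        // only agree when the weights were exported for the matching layout.
        let head = [1.0f32, 2.0, 3.0, 4.0];
        let (cos, sin) = ([0.5f32, 0.8], [0.6f32, 0.2]);
        println!("{:?}", rope_interleaved(&head, &cos, &sin));
        println!("{:?}", rope_half_split(&head, &cos, &sin));
    }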
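
The print_token helper in the third hunk generalizes what the old inline
printing special-cased: SentencePiece-style vocabularies spell the word
boundary as '▁' (U+2581), which maps to a space, and raw bytes as literal
token strings of the form "<0xNN>". The old code only translated "<0x0A>"
into a newline; the new code decodes any ASCII byte token. A minimal sketch
of just the parsing step, using a hypothetical helper name that is not in
the patch:

    // Byte-level tokens are spelled literally as "<0xNN>" in the vocabulary.
    fn parse_byte_token(piece: &str) -> Option<u8> {
        piece
            .strip_prefix("<0x")?
            .strip_suffix('>')
            .and_then(|t| u8::from_str_radix(t, 16).ok())
    }

    fn main() {
        assert_eq!(parse_byte_token("<0x0A>"), Some(b'\n')); // newline byte token
        assert_eq!(parse_byte_token("▁hello"), None); // ordinary piece
    }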
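
The last hunk splits the timing because the two phases behave very
differently: the whole prompt is handled in a single batched forward pass
(the block that samples the first next_token at offset 0), while the
remaining to_sample = sample_len - 1 tokens are generated one forward pass
at a time at offset prompt_tokens.len() + index. The report printed at the
end then looks something like this (the numbers are illustrative, for a
13-token prompt and sample_len = 100):

      13 prompt tokens processed: 52.31 token/s
      99 tokens generated: 10.24 token/s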