From df6f5240bae8d4279d9b857f06b75ec582aca30e Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 22 Sep 2023 20:03:16 +0100
Subject: [PATCH] Complete the mixformer implementation. (#930)

* Complete the mixformer implementation.

* Tweak the attention.

* Add the phi-1.5 example.

* Improve the phi example.

* Bugfix.

* Get the phi example to work.
---
 candle-examples/examples/phi/README.md      |  23 +++
 candle-examples/examples/phi/main.rs        | 163 +++++++++++++++++++
 candle-transformers/src/models/mixformer.rs | 166 ++++++++++++++++----
 3 files changed, 321 insertions(+), 31 deletions(-)
 create mode 100644 candle-examples/examples/phi/README.md
 create mode 100644 candle-examples/examples/phi/main.rs

diff --git a/candle-examples/examples/phi/README.md b/candle-examples/examples/phi/README.md
new file mode 100644
index 00000000..8cf053bd
--- /dev/null
+++ b/candle-examples/examples/phi/README.md
@@ -0,0 +1,23 @@
+# candle-phi: code generation model
+
+[phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a 1.3B parameter language model with a focus on code generation.
+
+## Running some examples
+
+```bash
+$ cargo run --example phi --release -- --prompt "def print_prime(n): "
+
+def print_prime(n):
+    print("Printing prime numbers")
+    for i in range(2, n+1):
+        if is_prime(i):
+            print(i)
+
+def is_prime(n):
+    if n <= 1:
+        return False
+    for i in range(2, int(math.sqrt(n))+1):
+        if n % i == 0:
+            return False
+    return True
+```
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
new file mode 100644
index 00000000..4b290cd8
--- /dev/null
+++ b/candle-examples/examples/phi/main.rs
@@ -0,0 +1,163 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::Parser;
+
+use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as Model};
+
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+    model: Model,
+    device: Device,
+    tokenizer: Tokenizer,
+    logits_processor: LogitsProcessor,
+}
+
+impl TextGeneration {
+    fn new(
+        model: Model,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        device: &Device,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            tokenizer,
+            logits_processor,
+            device: device.clone(),
+        }
+    }
+
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
+        println!("starting the inference loop");
+        print!("{prompt}");
+        std::io::stdout().flush()?;
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        let mut new_tokens = vec![];
+        let start_gen = std::time::Instant::now();
+        for index in 0..sample_len {
+            let context_size = if index > 0 { 1 } else { tokens.len() };
+            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+            let logits = self.model.forward(&input)?;
+            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            new_tokens.push(next_token);
+            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
+            print!("{token}");
+            std::io::stdout().flush()?;
+        }
+        let dt = start_gen.elapsed();
+        println!(
+            "{sample_len} tokens generated ({:.3} token/s)",
+            sample_len as f64 / dt.as_secs_f64(),
+        );
+        Ok(())
+    }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, default_value_t = 100)]
+    sample_len: usize,
+
+    #[arg(long, default_value = "microsoft/phi-1_5")]
+    model_id: String,
+
+    #[arg(long, default_value = "refs/pr/18")]
+    revision: String,
+
+    #[arg(long)]
+    weight_file: Option<String>,
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+
+    let start = std::time::Instant::now();
+    let api = Api::new()?;
+    let repo = api.repo(Repo::with_revision(
+        args.model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let tokenizer_filename = repo.get("tokenizer.json")?;
+    let filenames = match args.weight_file {
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
+        None => ["model.safetensors"]
+            .iter()
+            .map(|f| repo.get(f))
+            .collect::<std::result::Result<Vec<_>, _>>()?,
+    };
+    println!("retrieved the files in {:?}", start.elapsed());
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let weights = filenames
+        .iter()
+        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
+        .collect::<Result<Vec<_>>>()?;
+    let weights = weights
+        .iter()
+        .map(|f| Ok(f.deserialize()?))
+        .collect::<Result<Vec<_>>>()?;
+
+    let start = std::time::Instant::now();
+    let device = candle_examples::device(args.cpu)?;
+    let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
+    let config = Config::v1_5();
+    let model = Model::new(&config, vb)?;
+    println!("loaded the model in {:?}", start.elapsed());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        &device,
+    );
+    pipeline.run(&args.prompt, args.sample_len)?;
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 2674d34f..028c3567 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -1,10 +1,11 @@
-#![allow(unused)]
 /// MixFormer model.
 /// https://huggingface.co/microsoft/phi-1_5
 /// https://arxiv.org/abs/2309.05463
-use candle::{DType, Device, Module, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 
+const MAX_SEQ_LEN: usize = 4096;
+
 // https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
 #[derive(Debug, Clone, PartialEq)]
 pub struct Config {
@@ -21,8 +22,8 @@ pub struct Config {
     pad_vocab_size_multiple: usize,
 }
 
-impl Default for Config {
-    fn default() -> Self {
+impl Config {
+    pub fn v1() -> Self {
         Self {
             vocab_size: 50304,
             n_positions: 2048,
@@ -37,6 +38,22 @@ impl Default for Config {
             pad_vocab_size_multiple: 64,
         }
     }
+
+    pub fn v1_5() -> Self {
+        Self {
+            vocab_size: 51200,
+            n_positions: 2048,
+            n_embd: 2048,
+            n_layer: 24,
+            n_inner: None,
+            n_head: 32,
+            rotary_dim: usize::min(32, 2048 / 32),
+            activation_function: Activation::Gelu,
+            layer_norm_epsilon: 1e-5,
+            tie_word_embeddings: false,
+            pad_vocab_size_multiple: 64,
+        }
+    }
 }
 
 #[derive(Debug)]
@@ -58,7 +75,70 @@ impl Module for Embedding {
     }
 }
 
 #[derive(Debug)]
-struct RotaryEmbedding {}
+struct RotaryEmbedding {
+    sin: Tensor,
+    cos: Tensor,
+}
+
+impl RotaryEmbedding {
+    fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
+        let inv_freq: Vec<_> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
+        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+            .to_dtype(DType::F32)?
+            .reshape((max_seq_len, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        Ok(Self {
+            sin: freqs.sin()?,
+            cos: freqs.cos()?,
+        })
+    }
+
+    fn apply_rotary_emb_qkv(
+        &self,
+        qkv: &Tensor,
+        seqlen_offset: usize,
+    ) -> Result<(Tensor, Tensor, Tensor)> {
+        let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
+        if three != 3 {
+            candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
+        }
+        let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
+        let rotary_dim = rotary_dim * 2;
+        let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
+        let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
+        let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
+        let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
+        let q12 = q_rot.chunk(2, D::Minus1)?;
+        let k12 = k_rot.chunk(2, D::Minus1)?;
+        let (q1, q2) = (&q12[0], &q12[1]);
+        let (k1, k2) = (&k12[0], &k12[1]);
+        let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
+        let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
+        let q_rot = Tensor::cat(
+            &[
+                (q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
+                (q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
+            ],
+            D::Minus1,
+        )?;
+        let k_rot = Tensor::cat(
+            &[
+                (k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
+                (k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
+            ],
+            D::Minus1,
+        )?;
+        let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
+        let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
+        let v = qkv.i((.., .., 2))?;
+        Ok((q, k, v))
+    }
+}
 
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
@@ -87,18 +167,6 @@ impl Module for MLP {
     }
 }
 
-#[derive(Debug)]
-struct SelfAttention {
-    causal: bool,
-    softmax_scale: f64,
-}
-
-#[derive(Debug)]
-struct CrossAttention {
-    causal: bool,
-    softmax_scale: f64,
-}
-
 #[derive(Debug)]
 struct CausalLMHead {
     ln: candle_nn::LayerNorm,
@@ -126,7 +194,10 @@ impl Module for CausalLMHead {
 struct MHA {
     wqkv: candle_nn::Linear,
     out_proj: candle_nn::Linear,
+    rotary_emb: RotaryEmbedding,
+    kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
+    softmax_scale: f64,
 }
 
 impl MHA {
@@ -135,23 +206,59 @@ impl MHA {
         let op_size = cfg.n_embd;
         let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
         let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
+        let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
+        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
         Ok(Self {
             wqkv,
             out_proj,
             head_dim,
+            kv_cache: None,
+            rotary_emb,
+            softmax_scale,
         })
     }
-}
 
-impl Module for MHA {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let (b_size, seq_len, n_embd) = xs.dims3()?;
+    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self
             .wqkv
             .forward(xs)?
             .reshape((b_size, seq_len, 3, (), self.head_dim))?;
-        let context: Tensor = qkv; // TODO
-        context.flatten_from(D::Minus2)?.apply(&self.out_proj)
+        let seqlen_offset = match &self.kv_cache {
+            None => 0,
+            Some((prev_k, _)) => prev_k.dim(1)?,
+        };
+        // In the python implementation, a single tensor is returned with the third axis of size 3.
+        let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
+        let (k, v) = match &self.kv_cache {
+            None => (k, v),
+            Some((prev_k, prev_v)) => {
+                let k = Tensor::cat(&[prev_k, &k], 1)?;
+                let v = Tensor::cat(&[prev_v, &v], 1)?;
+                (k, v)
+            }
+        };
+        self.kv_cache = Some((k.clone(), v.clone()));
+        // scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        let q = q.transpose(1, 2)?.flatten_to(1)?; // b*h, t, d
+        let k = k.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
+        let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
+        let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
+
+        // TODO: Add the causal mask.
+        // causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
+        // scores = scores + causal_mask.to(dtype=scores.dtype)
+        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+
+        // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+        // attn_weights: b*h,t,s, v: b*h,s,d
+        let attn_output = attn_weights.matmul(&v)?;
+        // b*h,t,d
+        let attn_output = attn_output
+            .reshape((b_size, (), seq_len, self.head_dim))?
+            .transpose(1, 2)?
+            .flatten_from(D::Minus2)?;
+        attn_output.apply(&self.out_proj)
     }
 }
 
@@ -169,10 +276,8 @@ impl ParallelBlock {
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
         Ok(Self { ln, mixer, mlp })
     }
-}
 
-impl Module for ParallelBlock {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs;
         let xs = xs.apply(&self.ln)?;
         let attn_outputs = self.mixer.forward(&xs)?;
@@ -204,14 +309,13 @@ impl MixFormerSequentialForCausalLM {
             head,
         })
    }
-}
 
-impl Module for MixFormerSequentialForCausalLM {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (_b_size, seq_len) = xs.dims2()?;
         let mut xs = xs.apply(&self.embedding)?;
-        for block in self.blocks.iter() {
+        for block in self.blocks.iter_mut() {
             xs = block.forward(&xs)?
         }
-        xs.apply(&self.head)
+        xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
    }
 }
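
A note on `apply_rotary_emb_qkv` for reviewers: only the first `rotary_dim` channels of each head are rotated; the `q_pass`/`k_pass` channels pass through unchanged. For a channel pair $(x_1, x_2)$ taken from the two halves of the rotary slice at position $t$, with per-pair frequency $\theta_i = 10000^{-2i/d_{\mathrm{rot}}}$ (the `inv_freq` table built in `RotaryEmbedding::new`), the code applies the standard rotary-embedding rotation

$$\begin{pmatrix} x_1' \\ x_2' \end{pmatrix} = \begin{pmatrix} \cos t\theta_i & -\sin t\theta_i \\ \sin t\theta_i & \cos t\theta_i \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix},$$

which is exactly the `q1 * cos - q2 * sin` / `q1 * sin + q2 * cos` pair assembled with `broadcast_mul` and `Tensor::cat`. The `seqlen_offset` argument shifts $t$ by the number of cached positions so that incremental decoding keeps each token's original angle.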
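The attention in `MHA::forward` still carries a TODO for the causal mask and only quotes the reference PyTorch lines in comments. Below is a minimal, self-contained sketch of how such a mask could be built with candle tensor ops, following the `where_cond`-based masked-fill pattern used in other candle examples; the helper names `get_causal_mask` and `masked_fill` are illustrative and not part of this patch.

```rust
use candle::{DType, Device, Result, Tensor};

// Hypothetical helper (not in this patch): build a (seqlen_q, seqlen_k) u8
// mask where entry (i, j) is 1 iff key position j lies strictly after query
// position i.
fn get_causal_mask(seqlen_q: usize, seqlen_k: usize, device: &Device) -> Result<Tensor> {
    // With a kv cache, cached keys precede the current queries, hence the offset.
    let offset = seqlen_k - seqlen_q;
    let mask: Vec<u8> = (0..seqlen_q)
        .flat_map(|i| (0..seqlen_k).map(move |j| u8::from(j > i + offset)))
        .collect();
    Tensor::from_vec(mask, (seqlen_q, seqlen_k), device)
}

// Hypothetical helper: where the mask is 1, substitute a large negative value
// so that the subsequent softmax drives those attention weights to ~0.
fn masked_fill(scores: &Tensor, mask: &Tensor, neg: f32) -> Result<Tensor> {
    let neg = Tensor::new(neg, scores.device())?.broadcast_as(mask.shape().dims())?;
    mask.where_cond(&neg, scores)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Toy scores with batch*heads = 1, t = 3 queries, s = 3 keys.
    let scores = Tensor::zeros((1, 3, 3), DType::F32, &dev)?;
    let mask = get_causal_mask(3, 3, &dev)?.broadcast_as(scores.shape().dims())?;
    println!("{}", masked_fill(&scores, &mask, -10000.0)?);
    Ok(())
}
```

With the kv cache, `seqlen_k` exceeds `seqlen_q` during incremental decoding, which is why the sketch offsets the comparison instead of masking a square matrix; for the single-token steps taken by the example's generation loop the mask is all zeros, so skipping it there, as this patch does, still decodes correctly.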