diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
new file mode 100644
index 00000000..2674d34f
--- /dev/null
+++ b/candle-transformers/src/models/mixformer.rs
@@ -0,0 +1,217 @@
+#![allow(unused)]
+/// MixFormer model.
+/// https://huggingface.co/microsoft/phi-1_5
+/// https://arxiv.org/abs/2309.05463
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{Activation, VarBuilder};
+
+// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
+#[derive(Debug, Clone, PartialEq)]
+pub struct Config {
+    vocab_size: usize,
+    n_positions: usize,
+    n_embd: usize,
+    n_layer: usize,
+    n_inner: Option<usize>,
+    n_head: usize,
+    rotary_dim: usize,
+    activation_function: Activation,
+    layer_norm_epsilon: f64,
+    tie_word_embeddings: bool,
+    pad_vocab_size_multiple: usize,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            vocab_size: 50304,
+            n_positions: 2048,
+            n_embd: 1024,
+            n_layer: 20,
+            n_inner: None,
+            n_head: 16,
+            rotary_dim: usize::min(32, 1024 / 16),
+            activation_function: Activation::Gelu,
+            layer_norm_epsilon: 1e-5,
+            tie_word_embeddings: false,
+            pad_vocab_size_multiple: 64,
+        }
+    }
+}
+
+#[derive(Debug)]
+struct Embedding {
+    wte: candle_nn::Embedding,
+}
+
+impl Embedding {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let wte = candle_nn::embedding(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
+        Ok(Self { wte })
+    }
+}
+
+impl Module for Embedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.wte.forward(xs)
+    }
+}
+
+#[derive(Debug)]
+struct RotaryEmbedding {}
+
+#[derive(Debug)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+    fc1: candle_nn::Linear,
+    fc2: candle_nn::Linear,
+    act: Activation,
+}
+
+impl MLP {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
+        let fc1 = candle_nn::linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
+        let fc2 = candle_nn::linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
+        Ok(Self {
+            fc1,
+            fc2,
+            act: cfg.activation_function,
+        })
+    }
+}
+
+impl Module for MLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
+    }
+}
+
+#[derive(Debug)]
+struct SelfAttention {
+    causal: bool,
+    softmax_scale: f64,
+}
+
+#[derive(Debug)]
+struct CrossAttention {
+    causal: bool,
+    softmax_scale: f64,
+}
+
+#[derive(Debug)]
+struct CausalLMHead {
+    ln: candle_nn::LayerNorm,
+    linear: candle_nn::Linear,
+}
+
+impl CausalLMHead {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
+        let linear = candle_nn::linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
+        Ok(Self { ln, linear })
+    }
+}
+
+impl Module for CausalLMHead {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.ln)?
+            .apply(&self.linear)?
+            .to_dtype(DType::F32)
+    }
+}
+
+#[derive(Debug)]
+#[allow(clippy::upper_case_acronyms)]
+struct MHA {
+    wqkv: candle_nn::Linear,
+    out_proj: candle_nn::Linear,
+    head_dim: usize,
+}
+
+impl MHA {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let head_dim = cfg.n_embd / cfg.n_head;
+        let op_size = cfg.n_embd;
+        let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
+        let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
+        Ok(Self {
+            wqkv,
+            out_proj,
+            head_dim,
+        })
+    }
+}
+
+impl Module for MHA {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_size, seq_len, n_embd) = xs.dims3()?;
+        let qkv = self
+            .wqkv
+            .forward(xs)?
+            .reshape((b_size, seq_len, 3, (), self.head_dim))?;
+        let context: Tensor = qkv; // TODO
+        context.flatten_from(D::Minus2)?.apply(&self.out_proj)
+    }
+}
+
+#[derive(Debug)]
+struct ParallelBlock {
+    ln: candle_nn::LayerNorm,
+    mixer: MHA,
+    mlp: MLP,
+}
+
+impl ParallelBlock {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
+        let mixer = MHA::new(cfg, vb.pp("mixer"))?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        Ok(Self { ln, mixer, mlp })
+    }
+}
+
+impl Module for ParallelBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let residual = xs;
+        let xs = xs.apply(&self.ln)?;
+        let attn_outputs = self.mixer.forward(&xs)?;
+        let feed_forward_hidden_states = self.mlp.forward(&xs)?;
+        attn_outputs + feed_forward_hidden_states + residual
+    }
+}
+
+#[derive(Debug)]
+pub struct MixFormerSequentialForCausalLM {
+    embedding: Embedding,
+    blocks: Vec<ParallelBlock>,
+    head: CausalLMHead,
+}
+
+impl MixFormerSequentialForCausalLM {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb = vb.pp("layers");
+        let embedding = Embedding::new(cfg, vb.pp(0))?;
+        let mut blocks = Vec::new();
+        for i in 0..cfg.n_layer {
+            let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
+            blocks.push(block)
+        }
+        let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
+        Ok(Self {
+            embedding,
+            blocks,
+            head,
+        })
+    }
+}
+
+impl Module for MixFormerSequentialForCausalLM {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.apply(&self.embedding)?;
+        for block in self.blocks.iter() {
+            xs = block.forward(&xs)?
+        }
+        xs.apply(&self.head)
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index d783a2c6..991ee201 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -4,6 +4,7 @@ pub mod dinov2;
 pub mod efficientnet;
 pub mod falcon;
 pub mod llama;
+pub mod mixformer;
 pub mod quantized_llama;
 pub mod quantized_t5;
 pub mod segment_anything;
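
Not part of the patch itself, but a minimal usage sketch of how the new module is meant to be wired up. It builds `MixFormerSequentialForCausalLM` from the default `Config` with an all-zero `VarBuilder` (assuming `candle_nn::VarBuilder::zeros` as used elsewhere in this workspace), which is enough to exercise the `layers.*` variable layout and parameter shapes; since the attention context in `MHA::forward` is still a `TODO`, it stops short of a full forward pass.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // All-zero weights: enough to check that the "layers.{i}" layout
    // (layers.0 embedding, layers.1..=n_layer parallel blocks, final lm head)
    // and the parameter shapes line up, not to generate text.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let _model = MixFormerSequentialForCausalLM::new(&Config::default(), vb)?;
    // A single batch of three token ids, the input `_model.forward` would
    // expect once the attention mixer is implemented.
    let input = Tensor::new(&[[1u32, 2u32, 3u32]], &device)?;
    println!("input dims: {:?}", input.dims());
    Ok(())
}
```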