mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Mixformer (#929)
* Sketch the mixformer model. * More modeling code. * More mixformers. * MixFormer creation. * More mixformers.
This commit is contained in:
217
candle-transformers/src/models/mixformer.rs
Normal file
217
candle-transformers/src/models/mixformer.rs
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
/// MixFormer model.
|
||||||
|
/// https://huggingface.co/microsoft/phi-1_5
|
||||||
|
/// https://arxiv.org/abs/2309.05463
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{Activation, VarBuilder};
|
||||||
|
|
||||||
|
// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct Config {
|
||||||
|
vocab_size: usize,
|
||||||
|
n_positions: usize,
|
||||||
|
n_embd: usize,
|
||||||
|
n_layer: usize,
|
||||||
|
n_inner: Option<usize>,
|
||||||
|
n_head: usize,
|
||||||
|
rotary_dim: usize,
|
||||||
|
activation_function: Activation,
|
||||||
|
layer_norm_epsilon: f64,
|
||||||
|
tie_word_embeddings: bool,
|
||||||
|
pad_vocab_size_multiple: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 50304,
|
||||||
|
n_positions: 2048,
|
||||||
|
n_embd: 1024,
|
||||||
|
n_layer: 20,
|
||||||
|
n_inner: None,
|
||||||
|
n_head: 16,
|
||||||
|
rotary_dim: usize::min(32, 1024 / 16),
|
||||||
|
activation_function: Activation::Gelu,
|
||||||
|
layer_norm_epsilon: 1e-5,
|
||||||
|
tie_word_embeddings: false,
|
||||||
|
pad_vocab_size_multiple: 64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Embedding {
|
||||||
|
wte: candle_nn::Embedding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embedding {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let wte = candle_nn::embedding(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
|
||||||
|
Ok(Self { wte })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Embedding {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
self.wte.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct RotaryEmbedding {}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MLP {
|
||||||
|
fc1: candle_nn::Linear,
|
||||||
|
fc2: candle_nn::Linear,
|
||||||
|
act: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MLP {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
|
||||||
|
let fc1 = candle_nn::linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
|
||||||
|
let fc2 = candle_nn::linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
|
||||||
|
Ok(Self {
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
act: cfg.activation_function,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MLP {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct SelfAttention {
|
||||||
|
causal: bool,
|
||||||
|
softmax_scale: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct CrossAttention {
|
||||||
|
causal: bool,
|
||||||
|
softmax_scale: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct CausalLMHead {
|
||||||
|
ln: candle_nn::LayerNorm,
|
||||||
|
linear: candle_nn::Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CausalLMHead {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
|
||||||
|
let linear = candle_nn::linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
|
||||||
|
Ok(Self { ln, linear })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for CausalLMHead {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.ln)?
|
||||||
|
.apply(&self.linear)?
|
||||||
|
.to_dtype(DType::F32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MHA {
|
||||||
|
wqkv: candle_nn::Linear,
|
||||||
|
out_proj: candle_nn::Linear,
|
||||||
|
head_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MHA {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let head_dim = cfg.n_embd / cfg.n_head;
|
||||||
|
let op_size = cfg.n_embd;
|
||||||
|
let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
|
||||||
|
let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
wqkv,
|
||||||
|
out_proj,
|
||||||
|
head_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MHA {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let (b_size, seq_len, n_embd) = xs.dims3()?;
|
||||||
|
let qkv = self
|
||||||
|
.wqkv
|
||||||
|
.forward(xs)?
|
||||||
|
.reshape((b_size, seq_len, 3, (), self.head_dim))?;
|
||||||
|
let context: Tensor = qkv; // TODO
|
||||||
|
context.flatten_from(D::Minus2)?.apply(&self.out_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ParallelBlock {
|
||||||
|
ln: candle_nn::LayerNorm,
|
||||||
|
mixer: MHA,
|
||||||
|
mlp: MLP,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ParallelBlock {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
|
||||||
|
let mixer = MHA::new(cfg, vb.pp("mixer"))?;
|
||||||
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
Ok(Self { ln, mixer, mlp })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ParallelBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = xs.apply(&self.ln)?;
|
||||||
|
let attn_outputs = self.mixer.forward(&xs)?;
|
||||||
|
let feed_forward_hidden_states = self.mlp.forward(&xs)?;
|
||||||
|
attn_outputs + feed_forward_hidden_states + residual
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct MixFormerSequentialForCausalLM {
|
||||||
|
embedding: Embedding,
|
||||||
|
blocks: Vec<ParallelBlock>,
|
||||||
|
head: CausalLMHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MixFormerSequentialForCausalLM {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb = vb.pp("layers");
|
||||||
|
let embedding = Embedding::new(cfg, vb.pp(0))?;
|
||||||
|
let mut blocks = Vec::new();
|
||||||
|
for i in 0..cfg.n_layer {
|
||||||
|
let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
|
||||||
|
blocks.push(block)
|
||||||
|
}
|
||||||
|
let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
|
||||||
|
Ok(Self {
|
||||||
|
embedding,
|
||||||
|
blocks,
|
||||||
|
head,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MixFormerSequentialForCausalLM {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.apply(&self.embedding)?;
|
||||||
|
for block in self.blocks.iter() {
|
||||||
|
xs = block.forward(&xs)?
|
||||||
|
}
|
||||||
|
xs.apply(&self.head)
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,7 @@ pub mod dinov2;
|
|||||||
pub mod efficientnet;
|
pub mod efficientnet;
|
||||||
pub mod falcon;
|
pub mod falcon;
|
||||||
pub mod llama;
|
pub mod llama;
|
||||||
|
pub mod mixformer;
|
||||||
pub mod quantized_llama;
|
pub mod quantized_llama;
|
||||||
pub mod quantized_t5;
|
pub mod quantized_t5;
|
||||||
pub mod segment_anything;
|
pub mod segment_anything;
|
||||||
|
Reference in New Issue
Block a user