From 34d9e9174824cc0656e083364fe68b85666843e0 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 20 Oct 2023 22:09:11 +0100
Subject: [PATCH] Add the blip image captioning model (#1140)

* Blip text model.

* Blip vision bits.

* Blippity.

* More blip.
---
 candle-transformers/src/models/blip.rs      | 241 ++++++++++++
 candle-transformers/src/models/blip_text.rs | 350 ++++++++++++++++++
 candle-transformers/src/models/mod.rs       |   2 +
 .../src/models/with_tracing.rs              |   4 +-
 4 files changed, 595 insertions(+), 2 deletions(-)
 create mode 100644 candle-transformers/src/models/blip.rs
 create mode 100644 candle-transformers/src/models/blip_text.rs
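Note (not part of the patch): in VisionEmbeddings::new below, num_positions is the patch grid plus one extra slot for the class token, and it has to line up with the checkpoint's "position_embedding" tensor of shape (1, num_positions, hidden_size). A quick sanity check with assumed BLIP-base style values (image_size = 384, patch_size = 16, hidden_size = 768 are illustration numbers, not values taken from this patch):

fn main() {
    // Illustration only; the config values are assumptions, not read from the patch.
    let (image_size, patch_size, hidden_size) = (384usize, 16usize, 768usize);
    let num_patches1 = image_size / patch_size; // 24 patches per side
    let num_patches = num_patches1 * num_patches1; // 576 patches
    let num_positions = num_patches + 1; // 577: the extra slot is the class token
    // The position_embedding weight is then expected with shape (1, 577, 768).
    assert_eq!((1, num_positions, hidden_size), (1, 577, 768));
}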
diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs
new file mode 100644
index 00000000..4c2ca44d
--- /dev/null
+++ b/candle-transformers/src/models/blip.rs
@@ -0,0 +1,241 @@
+#![allow(unused)]
+use super::blip_text;
+use super::with_tracing::{conv2d, linear, Conv2d, Linear};
+use candle::{Module, Result, Tensor, D};
+use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
+
+#[derive(Debug, Clone)]
+struct VisionConfig {
+    hidden_size: usize,
+    intermediate_size: usize,
+    projection_dim: usize,
+    num_hidden_layers: usize,
+    num_attention_heads: usize,
+    image_size: usize,
+    patch_size: usize,
+    hidden_act: candle_nn::Activation,
+    layer_norm_eps: f64,
+}
+
+#[derive(Debug, Clone)]
+struct Config {
+    text_config: blip_text::Config,
+    vision_config: VisionConfig,
+    projection_dim: usize,
+    image_text_hidden_size: usize,
+}
+
+#[derive(Debug, Clone)]
+struct VisionEmbeddings {
+    class_embedding: Tensor,
+    patch_embedding: Conv2d,
+    position_embedding: Tensor,
+    num_positions: usize,
+}
+
+impl VisionEmbeddings {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let class_embedding = vb.get((1, 1, cfg.hidden_size), "class_embedding")?;
+        let conv_cfg = Conv2dConfig {
+            stride: cfg.patch_size,
+            ..Default::default()
+        };
+        let patch_embedding = conv2d(
+            3,
+            cfg.hidden_size,
+            cfg.patch_size,
+            conv_cfg,
+            vb.pp("patch_embedding"),
+        )?;
+        let num_patches1 = cfg.image_size / cfg.patch_size;
+        let num_patches = num_patches1 * num_patches1;
+        let num_positions = num_patches + 1;
+        let position_embedding =
+            vb.get((1, num_positions, cfg.hidden_size), "position_embedding")?;
+        Ok(Self {
+            class_embedding,
+            patch_embedding,
+            position_embedding,
+            num_positions,
+        })
+    }
+}
+
+impl Module for VisionEmbeddings {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let target_dtype = xs.dtype();
+        let b_size = xs.dim(0)?;
+        let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
+        let d = self.class_embedding.dim(D::Minus1)?;
+        let class_embeds = self
+            .class_embedding
+            .broadcast_as((b_size, 1, d))?
+            .to_dtype(target_dtype)?;
+        let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
+        let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
+        embeddings.broadcast_add(&position_embedding)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+    qkv: Linear,
+    projection: Linear,
+    scale: f64,
+    embed_dim: usize,
+    head_dim: usize,
+    num_heads: usize,
+}
+
+impl Attention {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let embed_dim = cfg.hidden_size;
+        let num_heads = cfg.num_attention_heads;
+        let head_dim = embed_dim / num_heads;
+        let scale = 1f64 / (head_dim as f64).sqrt();
+        let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
+        let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
+        Ok(Self {
+            qkv,
+            projection,
+            scale,
+            embed_dim,
+            head_dim,
+            num_heads,
+        })
+    }
+}
+
+impl Module for Attention {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
+        let mixed_qkv = xs
+            .apply(&self.qkv)?
+            .reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
+            .permute((2, 0, 3, 1, 4))?;
+        let query = mixed_qkv.get(0)?;
+        let key = mixed_qkv.get(1)?;
+        let value = mixed_qkv.get(2)?;
+        let attention_scores = query.matmul(&key.t()?)?;
+        let attention_scores = (attention_scores * self.scale)?;
+        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+        attention_probs
+            .matmul(&value)?
+            .permute((0, 2, 1, 3))?
+            .flatten_from(D::Minus2)?
+            .apply(&self.projection)
+    }
+}
+
+#[derive(Debug, Clone)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+    activation_fn: candle_nn::Activation,
+    fc1: Linear,
+    fc2: Linear,
+}
+
+impl MLP {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
+        let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
+        Ok(Self {
+            activation_fn: cfg.hidden_act,
+            fc1,
+            fc2,
+        })
+    }
+}
+
+impl Module for MLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.fc1)?
+            .apply(&self.activation_fn)?
+            .apply(&self.fc2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct EncoderLayer {
+    self_attn: Attention,
+    layer_norm1: LayerNorm,
+    mlp: MLP,
+    layer_norm2: LayerNorm,
+}
+
+impl EncoderLayer {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let embed_dim = cfg.hidden_size;
+        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
+        let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
+        let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        Ok(Self {
+            self_attn,
+            layer_norm1,
+            mlp,
+            layer_norm2,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Tensor) -> Result<Tensor> {
+        let residual = xs;
+        let xs = xs.apply(&self.layer_norm1)?;
+        todo!()
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Encoder {
+    layers: Vec<EncoderLayer>,
+}
+
+impl Encoder {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        for i in 0..cfg.num_hidden_layers {
+            let layer = EncoderLayer::new(cfg, vb.pp(i))?;
+            layers.push(layer)
+        }
+        Ok(Self { layers })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct VisionModel {
+    embeddings: VisionEmbeddings,
+    encoder: Encoder,
+    post_layernorm: LayerNorm,
+}
+
+impl VisionModel {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
+        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+        let post_layernorm =
+            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
+        Ok(Self {
+            embeddings,
+            encoder,
+            post_layernorm,
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct BlipForConditionalGeneration {
+    vision_model: VisionModel,
+    text_decoder: blip_text::TextLMHeadModel,
+}
+
+impl BlipForConditionalGeneration {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
+        let text_decoder =
+            blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
+        Ok(Self {
+            vision_model,
+            text_decoder,
+        })
+    }
+}
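Note (not part of the patch): EncoderLayer::forward above is left as a todo!() at this stage. A minimal sketch of the pre-norm residual wiring it will need, assuming Attention keeps the mask-free Module impl from this file; this is an assumption about the follow-up, not code from the commit:

    fn forward(&self, xs: &Tensor, _attention_mask: Tensor) -> Result<Tensor> {
        // Attention block: layer norm, self-attention, then the residual connection.
        let residual = xs;
        let xs = xs.apply(&self.layer_norm1)?.apply(&self.self_attn)?;
        let xs = (xs + residual)?;
        // Feed-forward block with its own layer norm and residual.
        let residual = &xs;
        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
        xs + residual
    }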
diff --git a/candle-transformers/src/models/blip_text.rs b/candle-transformers/src/models/blip_text.rs
new file mode 100644
index 00000000..8b4fb4d1
--- /dev/null
+++ b/candle-transformers/src/models/blip_text.rs
@@ -0,0 +1,350 @@
+#![allow(unused)]
+use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
+use candle::{Module, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, VarBuilder};
+
+#[derive(Debug, Clone)]
+pub struct Config {
+    vocab_size: usize,
+    hidden_size: usize,
+    encoder_hidden_size: usize,
+    intermediate_size: usize,
+    projection_dim: usize,
+    num_hidden_layers: usize,
+    num_attention_heads: usize,
+    max_position_embeddings: usize,
+    hidden_act: candle_nn::Activation,
+    layer_norm_eps: f64,
+    is_decoder: bool,
+}
+
+#[derive(Debug, Clone)]
+struct TextEmbeddings {
+    word_embedddings: Embedding,
+    position_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    position_ids: Tensor,
+}
+
+impl TextEmbeddings {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let word_embedddings =
+            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
+        let position_embeddings = Embedding::new(
+            cfg.max_position_embeddings,
+            cfg.hidden_size,
+            vb.pp("position_embeddings"),
+        )?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        let position_ids =
+            Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
+        Ok(Self {
+            word_embedddings,
+            position_embeddings,
+            layer_norm,
+            position_ids,
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextSelfAttention {
+    query: Linear,
+    key: Linear,
+    value: Linear,
+    all_head_size: usize,
+    attention_head_size: usize,
+    num_attention_heads: usize,
+}
+
+impl TextSelfAttention {
+    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
+        let num_attention_heads = cfg.num_attention_heads;
+        let attention_head_size = cfg.hidden_size / num_attention_heads;
+        let all_head_size = cfg.num_attention_heads * attention_head_size;
+        let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
+        let in_size = if is_cross_attention {
+            cfg.encoder_hidden_size
+        } else {
+            cfg.hidden_size
+        };
+        let key = linear(in_size, all_head_size, vb.pp("key"))?;
+        let value = linear(in_size, all_head_size, vb.pp("value"))?;
+        Ok(Self {
+            query,
+            key,
+            value,
+            all_head_size,
+            attention_head_size,
+            num_attention_heads,
+        })
+    }
+
+    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_size, seq_len, _) = xs.dims3()?;
+        xs.reshape((
+            b_size,
+            seq_len,
+            self.num_attention_heads,
+            self.attention_head_size,
+        ))?
+        .permute((0, 2, 1, 3))
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextSelfOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+}
+
+impl TextSelfOutput {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self { dense, layer_norm })
+    }
+
+    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        (xs.apply(&self.dense) + input_tensor)?.apply(&self.layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextAttention {
+    self_: TextSelfAttention,
+    output: TextSelfOutput,
+}
+
+impl TextAttention {
+    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
+        let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
+        let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
+        Ok(Self { self_, output })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextIntermediate {
+    dense: Linear,
+    intermediate_act_fn: candle_nn::Activation,
+}
+
+impl TextIntermediate {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
+        Ok(Self {
+            dense,
+            intermediate_act_fn: cfg.hidden_act,
+        })
+    }
+}
+
+impl Module for TextIntermediate {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+}
+
+impl TextOutput {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self { dense, layer_norm })
+    }
+
+    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextLayer {
+    attention: TextAttention,
+    cross_attention: Option<TextAttention>,
+    intermediate: TextIntermediate,
+    output: TextOutput,
+}
+
+impl TextLayer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
+        let cross_attention = if cfg.is_decoder {
+            Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
+        } else {
+            None
+        };
+        let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
+        let output = TextOutput::new(cfg, vb.pp("output"))?;
+        Ok(Self {
+            attention,
+            cross_attention,
+            intermediate,
+            output,
+        })
+    }
+}
+
+impl Module for TextLayer {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextEncoder {
+    layers: Vec<TextLayer>,
+}
+
+impl TextEncoder {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb = vb.pp("layer");
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        for i in 0..cfg.num_hidden_layers {
+            let layer = TextLayer::new(cfg, vb.pp(i))?;
+            layers.push(layer)
+        }
+        Ok(Self { layers })
+    }
+}
+
+impl Module for TextEncoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            xs = xs.apply(layer)?
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextPooler {
+    dense: Linear,
+}
+
+impl TextPooler {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        Ok(Self { dense })
+    }
+}
+
+impl Module for TextPooler {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.narrow(D::Minus1, 0, 1)?
+            .squeeze(D::Minus1)?
+            .apply(&self.dense)?
+            .tanh()
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextPredictionHeadTransform {
+    dense: Linear,
+    transform_act_fn: candle_nn::Activation,
+    layer_norm: LayerNorm,
+}
+
+impl TextPredictionHeadTransform {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self {
+            dense,
+            transform_act_fn: cfg.hidden_act,
+            layer_norm,
+        })
+    }
+}
+
+impl Module for TextPredictionHeadTransform {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.dense)?
+            .apply(&self.transform_act_fn)?
+            .apply(&self.layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextLMPredictionHead {
+    transform: TextPredictionHeadTransform,
+    decoder: Linear,
+    bias: Tensor,
+}
+
+impl TextLMPredictionHead {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
+        let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
+        let bias = vb.get(cfg.vocab_size, "bias")?;
+        Ok(Self {
+            transform,
+            decoder,
+            bias,
+        })
+    }
+}
+
+impl Module for TextLMPredictionHead {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.transform)?.apply(&self.decoder)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextOnlyMLMHead {
+    predictions: TextLMPredictionHead,
+}
+
+impl TextOnlyMLMHead {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
+        Ok(Self { predictions })
+    }
+}
+
+impl Module for TextOnlyMLMHead {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.predictions.forward(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextModel {
+    embeddings: TextEmbeddings,
+    encoder: TextEncoder,
+    pooler: Option<TextPooler>,
+}
+
+impl TextModel {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
+        let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
+        Ok(Self {
+            embeddings,
+            encoder,
+            pooler: None,
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextLMHeadModel {
+    bert: TextModel,
+    cls: TextOnlyMLMHead,
+}
+
+impl TextLMHeadModel {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let bert = TextModel::new(cfg, vb.pp("bert"))?;
+        let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
+        Ok(Self { bert, cls })
+    }
+}
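Note (not part of the patch): TextSelfAttention above only defines its projections and transpose_for_scores, and TextLayer's Module::forward is a todo!(). A sketch of the plain self-attention step these helpers point towards (no attention mask, no cross-attention inputs yet); an assumption about where the code is heading, not code from the commit:

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Project to (batch, heads, seq, head_dim) for each of query/key/value.
        let query = self.transpose_for_scores(&xs.apply(&self.query)?)?;
        let key = self.transpose_for_scores(&xs.apply(&self.key)?)?;
        let value = self.transpose_for_scores(&xs.apply(&self.value)?)?;
        // Scaled dot-product attention with a softmax over the key positions.
        let scale = 1f64 / (self.attention_head_size as f64).sqrt();
        let scores = (query.matmul(&key.t()?)? * scale)?;
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        // Back to (batch, seq, all_head_size).
        probs
            .matmul(&value)?
            .permute((0, 2, 1, 3))?
            .flatten_from(D::Minus2)
    }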
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 88c622d8..6836b9c0 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -1,5 +1,7 @@
 pub mod bert;
 pub mod bigcode;
+pub mod blip;
+pub mod blip_text;
 pub mod convmixer;
 pub mod dinov2;
 pub mod efficientnet;
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index edd8d657..69654139 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -58,8 +58,8 @@ pub struct Conv2d {
     span: tracing::Span,
 }
 
-impl Conv2d {
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl Module for Conv2d {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         self.inner.forward(x)
     }
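Note (not part of the patch): the with_tracing.rs hunk above is what makes xs.apply(&self.patch_embedding)? in VisionEmbeddings::forward compile, since Tensor::apply only accepts types that implement candle::Module. The same pattern in isolation, shown on a made-up wrapper (the names below are illustrative, not from the crate):

use candle::{Module, Result, Tensor};

// A hypothetical traced wrapper in the spirit of with_tracing::Conv2d.
struct Traced<M> {
    inner: M,
    span: tracing::Span,
}

impl<M: Module> Module for Traced<M> {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Enter the span for profiling, then defer to the wrapped module.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

// Because Traced<M> implements Module, it composes with Tensor::apply.
fn run<M: Module>(m: &Traced<M>, xs: &Tensor) -> Result<Tensor> {
    xs.apply(m)
}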