From 94e3373883caaa7442201dac25abe16b4469f9bd Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 21 Oct 2023 10:19:23 +0100
Subject: [PATCH] Blip forward pass (#1141)

* More forward methods for the blip model.

* Blipping continues.
---
 candle-transformers/src/models/blip.rs | 47 +++++++++++++++++++++++---
 1 file changed, 42 insertions(+), 5 deletions(-)

diff --git a/candle-transformers/src/models/blip.rs b/candle-transformers/src/models/blip.rs
index 4c2ca44d..dd1bcd48 100644
--- a/candle-transformers/src/models/blip.rs
+++ b/candle-transformers/src/models/blip.rs
@@ -104,10 +104,8 @@ impl Attention {
             num_heads,
         })
     }
-}
 
-impl Module for Attention {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
         let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
         let mixed_qkv = xs
             .apply(&self.qkv)?
@@ -119,6 +117,10 @@ impl Module for Attention {
         let attention_scores = query.matmul(&key.t()?)?;
         let attention_scores = (attention_scores * self.scale)?;
         let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+        let attention_probs = match attn_mask {
+            None => attention_probs,
+            Some(attn_mask) => (attention_probs * attn_mask)?,
+        };
         attention_probs
             .matmul(&value)?
             .permute((0, 2, 1, 3))?
@@ -178,10 +180,15 @@ impl EncoderLayer {
         })
     }
 
-    fn forward(&self, xs: &Tensor, attention_mask: Tensor) -> Result<Tensor> {
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
         let residual = xs;
         let xs = xs.apply(&self.layer_norm1)?;
-        todo!()
+        let xs = self.self_attn.forward(&xs, attention_mask)?;
+        let xs = (xs + residual)?;
+
+        let residual = &xs;
+        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
+        xs + residual
     }
 }
 
@@ -199,6 +206,14 @@ impl Encoder {
         }
         Ok(Self { layers })
     }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs, attention_mask)?
+        }
+        Ok(xs)
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -222,6 +237,19 @@ impl VisionModel {
     }
 }
 
+impl Module for VisionModel {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = xs.apply(&self.embeddings)?;
+        let encoder_outputs = self.encoder.forward(&xs, None)?;
+        let last_hidden_state = encoder_outputs.get(0)?;
+        last_hidden_state
+            .apply(&self.post_layernorm)?
+            .narrow(1, 0, 1)?
+            .squeeze(1)?
+            .apply(&self.post_layernorm)
+    }
+}
+
 #[derive(Debug, Clone)]
 struct BlipForConditionalGeneration {
     vision_model: VisionModel,
@@ -238,4 +266,13 @@ impl BlipForConditionalGeneration {
             text_decoder,
         })
     }
+
+    fn forward(
+        &self,
+        pixel_values: &Tensor,
+        input_ids: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        todo!()
+    }
 }
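
Note (illustrative, not part of the patch): the new attn_mask parameter in Attention::forward is applied multiplicatively to the softmax probabilities rather than added to the logits. A minimal standalone sketch of that pattern, assuming the candle_core and candle_nn crates and a mask already shaped like the probability tensor; the helper name masked_attention_probs is made up for illustration.

use candle_core::{DType, Device, Result, Tensor};

// Softmax over the last dim, then optionally zero out masked positions
// by elementwise multiplication, mirroring the pattern added to
// Attention::forward in the patch above.
fn masked_attention_probs(scores: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
    let probs = candle_nn::ops::softmax_last_dim(scores)?;
    match mask {
        None => Ok(probs),
        Some(mask) => probs * mask,
    }
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // (batch, heads, tgt_len, src_len) attention scores.
    let scores = Tensor::randn(0f32, 1f32, (1, 2, 4, 4), &device)?;
    // A mask of ones keeps every position; zeros would drop them.
    let mask = Tensor::ones((1, 2, 4, 4), DType::F32, &device)?;
    let probs = masked_attention_probs(&scores, Some(&mask))?;
    println!("{:?}", probs.dims());
    Ok(())
}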
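
Note (illustrative, not part of the patch): in the new Module impl for VisionModel, narrow(1, 0, 1) followed by squeeze(1) keeps only the first token along the sequence dimension, i.e. the CLS-style pooled output. A small sketch of the shape effect on a dummy (batch, seq_len, hidden) tensor, again assuming candle_core.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Dummy hidden states: batch = 2, seq_len = 5, hidden = 8.
    let hidden = Tensor::zeros((2, 5, 8), DType::F32, &device)?;
    // Keep only the first token along dim 1, then drop that dim.
    let pooled = hidden.narrow(1, 0, 1)?.squeeze(1)?;
    let (b, h) = pooled.dims2()?;
    assert_eq!((b, h), (2, 8));
    println!("pooled shape: ({b}, {h})");
    Ok(())
}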