Blip forward pass (#1141)

* More forward methods for the blip model.

* Blipping continues.
This commit is contained in:
Laurent Mazare
2023-10-21 10:19:23 +01:00
committed by GitHub
parent 34d9e91748
commit 94e3373883

View File

@ -104,10 +104,8 @@ impl Attention {
num_heads,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
let mixed_qkv = xs
.apply(&self.qkv)?
@ -119,6 +117,10 @@ impl Module for Attention {
let attention_scores = query.matmul(&key.t()?)?;
let attention_scores = (attention_scores * self.scale)?;
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
let attention_probs = match attn_mask {
None => attention_probs,
Some(attn_mask) => (attention_probs * attn_mask)?,
};
attention_probs
.matmul(&value)?
.permute((0, 2, 1, 3))?
@ -178,10 +180,15 @@ impl EncoderLayer {
})
}
fn forward(&self, xs: &Tensor, attention_mask: Tensor) -> Result<Tensor> {
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.layer_norm1)?;
todo!()
let xs = self.self_attn.forward(&xs, attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
xs + residual
}
}
@ -199,6 +206,14 @@ impl Encoder {
}
Ok(Self { layers })
}
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, attention_mask)?
}
Ok(xs)
}
}
#[derive(Debug, Clone)]
@ -222,6 +237,19 @@ impl VisionModel {
}
}
impl Module for VisionModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embeddings)?;
let encoder_outputs = self.encoder.forward(&xs, None)?;
let last_hidden_state = encoder_outputs.get(0)?;
last_hidden_state
.apply(&self.post_layernorm)?
.narrow(1, 0, 1)?
.squeeze(1)?
.apply(&self.post_layernorm)
}
}
#[derive(Debug, Clone)]
struct BlipForConditionalGeneration {
vision_model: VisionModel,
@ -238,4 +266,13 @@ impl BlipForConditionalGeneration {
text_decoder,
})
}
fn forward(
&self,
pixel_values: &Tensor,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
todo!()
}
}