mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Blip forward pass (#1141)
* More forward methods for the blip model.
* Blipping continues.
This commit is contained in:
@ -104,10 +104,8 @@ impl Attention {
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
|
||||
let mixed_qkv = xs
|
||||
.apply(&self.qkv)?
|
||||
@ -119,6 +117,10 @@ impl Module for Attention {
|
||||
let attention_scores = query.matmul(&key.t()?)?;
|
||||
let attention_scores = (attention_scores * self.scale)?;
|
||||
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
|
||||
let attention_probs = match attn_mask {
|
||||
None => attention_probs,
|
||||
Some(attn_mask) => (attention_probs * attn_mask)?,
|
||||
};
|
||||
attention_probs
|
||||
.matmul(&value)?
|
||||
.permute((0, 2, 1, 3))?
|
||||
@ -178,10 +180,15 @@ impl EncoderLayer {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, attention_mask: Tensor) -> Result<Tensor> {
|
||||
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.layer_norm1)?;
|
||||
todo!()
|
||||
let xs = self.self_attn.forward(&xs, attention_mask)?;
|
||||
let xs = (xs + residual)?;
|
||||
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,6 +206,14 @@ impl Encoder {
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
|
||||
/// Threads `xs` through every layer of `self.layers`, in order.
///
/// Each layer receives the previous layer's output together with the same
/// `attention_mask`. With no layers the result is a clone of `xs`.
///
/// # Errors
/// Stops at, and returns, the first error produced by a layer.
fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
    self.layers
        .iter()
        .try_fold(xs.clone(), |hidden, layer| layer.forward(&hidden, attention_mask))
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -222,6 +237,19 @@ impl VisionModel {
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for VisionModel {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.embeddings)?;
|
||||
let encoder_outputs = self.encoder.forward(&xs, None)?;
|
||||
let last_hidden_state = encoder_outputs.get(0)?;
|
||||
last_hidden_state
|
||||
.apply(&self.post_layernorm)?
|
||||
.narrow(1, 0, 1)?
|
||||
.squeeze(1)?
|
||||
.apply(&self.post_layernorm)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct BlipForConditionalGeneration {
|
||||
vision_model: VisionModel,
|
||||
@ -238,4 +266,13 @@ impl BlipForConditionalGeneration {
|
||||
text_decoder,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
pixel_values: &Tensor,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user