Blip attention mask + readme (#1146)

* Add the attention mask to the blip model.

* Add a readme.

Laurent Mazare, 2023-10-21 22:44:13 +01:00, committed by GitHub
parent 2531b13bf8
commit 3115fe42e4
2 changed files with 68 additions and 13 deletions
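
For context, the attention_mask threaded through the text model below is an additive mask: positions that should be hidden hold f32::NEG_INFINITY, and the mask is added to the raw attention scores before the softmax so those positions end up with near-zero probability. A minimal sketch of that pattern, assuming candle-core (imported as candle, as in candle-transformers) and candle-nn; the helper name masked_attention_probs is made up for illustration and is not part of this commit:

use candle::{Result, Tensor};

// Apply an optional additive attention mask to raw scores, then softmax over
// the last dimension. Masked positions carry f32::NEG_INFINITY in the mask.
fn masked_attention_probs(scores: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
    let scores = match mask {
        Some(mask) => scores.broadcast_add(mask)?,
        None => scores.clone(),
    };
    candle_nn::ops::softmax_last_dim(&scores)
}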

@@ -105,7 +105,12 @@ impl TextSelfAttention {
             .permute((0, 2, 1, 3))
     }
 
-    fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let query = self
             .transpose_for_scores(&self.query.forward(xs)?)?
             .contiguous()?;
@@ -127,6 +132,10 @@ impl TextSelfAttention {
         let value = value.contiguous()?;
         let attention_scores = query.matmul(&key.t()?)?;
         let attention_scores = (attention_scores * self.attention_scale)?;
+        let attention_scores = match attention_mask {
+            Some(mask) => attention_scores.broadcast_add(mask)?,
+            None => attention_scores,
+        };
         let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
         attention_probs
             .matmul(&value)?
@@ -166,8 +175,15 @@ impl TextAttention {
         Ok(Self { self_, output })
     }
 
-    fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
-        let self_outputs = self.self_.forward(xs, encoder_hidden_states)?;
+    fn forward(
+        &self,
+        xs: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let self_outputs = self
+            .self_
+            .forward(xs, encoder_hidden_states, attention_mask)?;
         self.output.forward(&self_outputs, xs)
     }
 }
@@ -238,10 +254,15 @@ impl TextLayer {
         })
     }
 
-    fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
-        let attention_output = self.attention.forward(xs, None)?;
+    fn forward(
+        &self,
+        xs: &Tensor,
+        encoder_hidden_states: &Tensor,
+        attention_mask: &Tensor,
+    ) -> Result<Tensor> {
+        let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
         let attention_output = match &self.cross_attention {
-            Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states))?,
+            Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
             None => candle::bail!("expected some cross-attn"),
         };
         let intermediate_output = self.intermediate.forward(&attention_output)?;
@@ -265,10 +286,15 @@ impl TextEncoder {
         Ok(Self { layers })
     }
 
-    fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        encoder_hidden_states: &Tensor,
+        attention_mask: &Tensor,
+    ) -> Result<Tensor> {
         let mut xs = xs.clone();
         for layer in self.layers.iter() {
-            xs = layer.forward(&xs, encoder_hidden_states)?
+            xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
         }
         Ok(xs)
     }
@@ -384,11 +410,16 @@ impl TextModel {
         })
     }
 
-    fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        input_ids: &Tensor,
+        encoder_hidden_states: &Tensor,
+        attention_mask: &Tensor,
+    ) -> Result<Tensor> {
         let embedding_output = self.embeddings.forward(input_ids)?;
-        let sequence_output = self
-            .encoder
-            .forward(&embedding_output, encoder_hidden_states)?;
+        let sequence_output =
+            self.encoder
+                .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
         // We're interested in the sequence-output rather than the pooled-output.
         Ok(sequence_output)
     }
@@ -408,7 +439,12 @@ impl TextLMHeadModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
-        let sequence_output = self.bert.forward(input_ids, encoder_hidden_states)?;
+        let seq_len = input_ids.dim(1)?;
+        let mask: Vec<_> = (0..seq_len)
+            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+            .collect();
+        let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
+        let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
         let prediction_scores = self.cls.forward(&sequence_output)?;
         // return_logits is false so we don't discard the last sequence element.
         Ok(prediction_scores)
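
The (seq_len, seq_len) causal mask built in TextLMHeadModel::forward above is not expanded to the full score shape; inside TextSelfAttention it is combined with the (batch, heads, seq_len, seq_len) attention scores via broadcast_add, which stretches it over the leading batch and head dimensions. A rough standalone sketch of that behavior; the causal_mask helper and the dummy shapes are assumptions for illustration, not code from this commit:

use candle::{DType, Device, Result, Tensor};

// Build a causal mask mirroring the one above: 0 on and below the diagonal,
// -inf above it, so each position only attends to itself and earlier positions.
fn causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
        .collect();
    Tensor::from_vec(mask, (seq_len, seq_len), device)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Dummy attention scores shaped (batch, heads, seq_len, seq_len).
    let scores = Tensor::zeros((2, 8, 4, 4), DType::F32, &device)?;
    let mask = causal_mask(4, &device)?;
    // broadcast_add stretches the (4, 4) mask over the leading batch/head dims.
    let masked = scores.broadcast_add(&mask)?;
    println!("{:?}", masked.shape()); // still (2, 8, 4, 4)
    Ok(())
}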