Add the blip example. (#1144)

* Add the blip example.

* Tweak the example.

* Implement the cross-attn logic.

* Fix some shape mismatches.

* Get some logits out.

* Get some caption to be generated.
Laurent Mazare, 2023-10-21 20:05:02 +01:00 (committed by GitHub)
parent e8f760ee44, commit 0d9bb4eb18
3 changed files with 223 additions and 45 deletions
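The hunks shown below are from the BLIP text-decoder module: they make the text `Config` fields public, fill in the embedding and attention forward passes, and wire up cross-attention so the decoder attends over the image encoder output (`encoder_hidden_states`). As a reading aid, here is a minimal, self-contained sketch of that cross-attention shape logic (toy sizes, single head, no projections; this is not code from the commit):

```rust
use candle::{Device, Result, Tensor};

// Cross-attention in a nutshell: queries come from the text decoder, keys and
// values come from the image encoder output, so each text position ends up as
// an image-conditioned mixture of image tokens.
fn cross_attention_sketch() -> Result<Tensor> {
    let dev = Device::Cpu;
    let (b, text_len, image_len, dim) = (1usize, 5, 10, 64);
    let text_states = Tensor::randn(0f32, 1f32, (b, text_len, dim), &dev)?;
    let image_states = Tensor::randn(0f32, 1f32, (b, image_len, dim), &dev)?;
    let scale = 1f64 / (dim as f64).sqrt();
    // (b, text_len, dim) x (b, dim, image_len) -> (b, text_len, image_len)
    let scores = (text_states.matmul(&image_states.t()?.contiguous()?)? * scale)?;
    let probs = candle_nn::ops::softmax_last_dim(&scores)?;
    // (b, text_len, image_len) x (b, image_len, dim) -> (b, text_len, dim)
    probs.matmul(&image_states)
}
```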

@@ -5,17 +5,17 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
#[derive(Debug, Clone)]
pub struct Config {
vocab_size: usize,
hidden_size: usize,
encoder_hidden_size: usize,
intermediate_size: usize,
projection_dim: usize,
num_hidden_layers: usize,
num_attention_heads: usize,
max_position_embeddings: usize,
hidden_act: candle_nn::Activation,
layer_norm_eps: f64,
is_decoder: bool,
pub vocab_size: usize,
pub hidden_size: usize,
pub encoder_hidden_size: usize,
pub intermediate_size: usize,
pub projection_dim: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub max_position_embeddings: usize,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
pub is_decoder: bool,
}
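Making the `Config` fields `pub` is what lets the new example crate build the text-decoder configuration directly. A hypothetical construction from downstream code could look as follows; the module path and the concrete values are illustrative (BERT-base-style sizes), not authoritative BLIP defaults, and should come from the checkpoint's configuration in practice:

```rust
use candle_transformers::models::blip_text;

// Illustrative only: field values are placeholders, take them from the model config.
fn text_config() -> blip_text::Config {
    blip_text::Config {
        vocab_size: 30524,
        hidden_size: 768,
        encoder_hidden_size: 768, // width of the image-encoder states fed to cross-attention
        intermediate_size: 3072,
        projection_dim: 768,
        num_hidden_layers: 12,
        num_attention_heads: 12,
        max_position_embeddings: 512,
        hidden_act: candle_nn::Activation::Gelu,
        layer_norm_eps: 1e-12,
        is_decoder: true, // required so the cross-attention blocks get built
    }
}
```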
#[derive(Debug, Clone)]
@@ -47,6 +47,17 @@ impl TextEmbeddings {
}
}
impl Module for TextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let seq_len = xs.dim(1)?;
// Use past_key_values_length if we add a kv cache.
let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
let embeddings = self.word_embedddings.forward(xs)?;
let position_embeddings = self.position_embeddings.forward(&position_ids)?;
(embeddings + position_embeddings)?.apply(&self.layer_norm)
}
}
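The new `TextEmbeddings::forward` adds learned absolute position embeddings to the word embeddings: a `(1, max_position_embeddings)` buffer of position ids is narrowed to the current sequence length on each call (and, as the comment notes, would need an offset once a kv-cache is added). An isolated sketch of that indexing, using a zero-initialised `VarBuilder` just to exercise the shapes:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{embedding, Module, VarBuilder};

// Positions are a fixed (1, max_len) buffer; each forward pass slices off the
// first `seq_len` entries and looks them up in a learned embedding table.
fn position_embeddings_sketch(seq_len: usize) -> Result<Tensor> {
    let dev = Device::Cpu;
    let (max_len, hidden) = (512usize, 32);
    let vb = VarBuilder::zeros(DType::F32, &dev); // zero weights, shapes only
    let pos_emb = embedding(max_len, hidden, vb.pp("position_embeddings"))?;
    let position_ids = Tensor::arange(0u32, max_len as u32, &dev)?.unsqueeze(0)?; // (1, max_len)
    let ids = position_ids.narrow(1, 0, seq_len)?; // (1, seq_len)
    pos_emb.forward(&ids) // (1, seq_len, hidden)
}
```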
#[derive(Debug, Clone)]
struct TextSelfAttention {
query: Linear,
@@ -55,6 +66,7 @@ struct TextSelfAttention {
all_head_size: usize,
attention_head_size: usize,
num_attention_heads: usize,
attention_scale: f64,
}
impl TextSelfAttention {
@@ -70,6 +82,7 @@ impl TextSelfAttention {
};
let key = linear(in_size, all_head_size, vb.pp("key"))?;
let value = linear(in_size, all_head_size, vb.pp("value"))?;
let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
Ok(Self {
query,
key,
@@ -77,6 +90,7 @@ impl TextSelfAttention {
all_head_size,
attention_head_size,
num_attention_heads,
attention_scale,
})
}
@@ -90,6 +104,35 @@ impl TextSelfAttention {
))?
.permute((0, 2, 1, 3))
}
fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
let query = self
.transpose_for_scores(&self.query.forward(xs)?)?
.contiguous()?;
let (key, value) = match encoder_hidden_states {
None => {
let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
// TODO: kv cache
(key, value)
}
Some(xs) => {
let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
// no kv-cache in this case, but the results could probably be memoized.
(key, value)
}
};
let key = key.contiguous()?;
let value = value.contiguous()?;
let attention_scores = query.matmul(&key.t()?)?;
let attention_scores = (attention_scores * self.attention_scale)?;
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
attention_probs
.matmul(&value)?
.permute((0, 2, 1, 3))?
.flatten_from(D::Minus2)
}
}
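`transpose_for_scores` and the tail of `forward` are pure shape plumbing: the hidden dimension is split into `(num_attention_heads, attention_head_size)`, the heads are moved next to the batch dimension for the matmuls, and afterwards the permutation is undone and the heads merged back. The newly added `attention_scale` is the usual `1/sqrt(head_dim)` factor applied to the raw scores. The shape round-trip on a dummy tensor (toy sizes, not code from the commit):

```rust
use candle::{D, Device, Result, Tensor};

// The head split/merge performed by transpose_for_scores and the end of forward.
fn head_split_roundtrip() -> Result<()> {
    let dev = Device::Cpu;
    let (b, seq, heads, head_dim) = (2usize, 7, 4, 16);
    let xs = Tensor::randn(0f32, 1f32, (b, seq, heads * head_dim), &dev)?;
    // (b, seq, hidden) -> (b, seq, heads, head_dim) -> (b, heads, seq, head_dim)
    let split = xs.reshape((b, seq, heads, head_dim))?.permute((0, 2, 1, 3))?;
    assert_eq!(split.dims4()?, (b, heads, seq, head_dim));
    // After the attention matmuls: undo the permutation, merge the heads again.
    let merged = split.permute((0, 2, 1, 3))?.flatten_from(D::Minus2)?;
    assert_eq!(merged.dims3()?, (b, seq, heads * head_dim));
    Ok(())
}
```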
#[derive(Debug, Clone)]
@@ -122,6 +165,11 @@ impl TextAttention {
let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
Ok(Self { self_, output })
}
fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
let self_outputs = self.self_.forward(xs, encoder_hidden_states)?;
self.output.forward(&self_outputs, xs)
}
}
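`TextAttention::forward` hands both the attention result and the original input to `TextSelfOutput`. That struct sits outside this diff; in BERT-style blocks it is typically a dense projection, a residual add of the block input, and a LayerNorm, roughly as in the sketch below (an assumption about its shape, not the commit's code):

```rust
use candle::{Result, Tensor};
use candle_nn::{LayerNorm, Linear};

// Assumed layout of the output block: project, add the residual, normalise.
struct SelfOutputSketch {
    dense: Linear,
    layer_norm: LayerNorm,
}

impl SelfOutputSketch {
    fn forward(&self, hidden_states: &Tensor, input: &Tensor) -> Result<Tensor> {
        (hidden_states.apply(&self.dense)? + input)?.apply(&self.layer_norm)
    }
}
```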
#[derive(Debug, Clone)]
@@ -176,7 +224,7 @@ impl TextLayer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
let cross_attention = if cfg.is_decoder {
Some(TextAttention::new(cfg, true, vb.pp("attention"))?)
Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
} else {
None
};
@@ -189,11 +237,15 @@ impl TextLayer {
output,
})
}
}
impl Module for TextLayer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
todo!()
fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
let attention_output = self.attention.forward(xs, None)?;
let attention_output = match &self.cross_attention {
Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states))?,
None => candle::bail!("expected some cross-attn"),
};
let intermediate_output = self.intermediate.forward(&attention_output)?;
self.output.forward(&intermediate_output, &attention_output)
}
}
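Two things happen in `TextLayer`: the cross-attention weights now load from the `crossattention` prefix instead of reusing the self-attention ones, and `forward` runs self-attention over the text, then the now-mandatory cross-attention over the image features (hence the `bail!` when the layer was built without one, i.e. `is_decoder` must be set), then the feed-forward block. The prefix fix matters because `VarBuilder::pp` composes dotted tensor names; a sketch of how a layer's cross-attention query weights resolve (zero-initialised builder, and the surrounding `bert.encoder.layer.0` path is the usual BERT-style layout, which may not match the checkpoint exactly):

```rust
use candle::{DType, Device, Result};
use candle_nn::{linear, Linear, VarBuilder};

// Each `pp` call appends one dotted segment, so this Linear looks for tensors
// named "bert.encoder.layer.0.crossattention.self.query.{weight,bias}".
fn cross_attn_query(hidden_size: usize) -> Result<Linear> {
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let vb_layer = vb.pp("bert").pp("encoder").pp("layer").pp("0");
    linear(hidden_size, hidden_size, vb_layer.pp("crossattention").pp("self").pp("query"))
}
```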
@@ -212,13 +264,11 @@ impl TextEncoder {
}
Ok(Self { layers })
}
}
impl Module for TextEncoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = xs.apply(layer)?
xs = layer.forward(&xs, encoder_hidden_states)?
}
Ok(xs)
}
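`TextEncoder` (like `TextLayer` above) stops implementing `Module`: candle's `Module::forward` only accepts a single tensor, so once `encoder_hidden_states` has to be threaded through every layer, the `xs.apply(layer)` shortcut no longer fits and an inherent two-argument `forward` is used instead. For reference, the trait being dropped here is just:

```rust
use candle::{Result, Tensor};

// candle's single-input module trait (re-exported by candle_nn), restated for reference.
pub trait Module {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
```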
@@ -333,6 +383,15 @@ impl TextModel {
pooler: None,
})
}
fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
let embedding_output = self.embeddings.forward(input_ids)?;
let sequence_output = self
.encoder
.forward(&embedding_output, encoder_hidden_states)?;
// We're interested in the sequence-output rather than the pooled-output.
Ok(sequence_output)
}
}
#[derive(Debug, Clone)]
@@ -347,4 +406,11 @@ impl TextLMHeadModel {
let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
Ok(Self { bert, cls })
}
pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
let sequence_output = self.bert.forward(input_ids, encoder_hidden_states)?;
let prediction_scores = self.cls.forward(&sequence_output)?;
// return_logits is false so we don't discard the last sequence element.
Ok(prediction_scores)
}
}
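With `TextLMHeadModel::forward` returning per-position vocabulary logits, the "get some caption to be generated" step boils down to a decoding loop: feed the tokens generated so far together with the image features, take the logits of the last position, pick a token, repeat. A hypothetical greedy version follows; the module path, the token ids and the origin of `image_embeds` are assumptions, and the actual example may decode differently (there is also no kv-cache yet, matching the TODO above):

```rust
use candle::{D, Device, IndexOp, Result, Tensor};
use candle_transformers::models::blip_text::TextLMHeadModel;

// Hypothetical greedy decoding loop around the forward pass above.
fn greedy_caption(
    model: &TextLMHeadModel,
    image_embeds: &Tensor, // (1, num_image_tokens, encoder_hidden_size) from the vision tower
    bos_token_id: u32,
    sep_token_id: u32,
    max_len: usize,
    dev: &Device,
) -> Result<Vec<u32>> {
    let mut tokens = vec![bos_token_id];
    for _ in 0..max_len {
        let input_ids = Tensor::new(tokens.as_slice(), dev)?.unsqueeze(0)?; // (1, seq_len)
        let logits = model.forward(&input_ids, image_embeds)?; // (1, seq_len, vocab_size)
        let last = logits.i((0, tokens.len() - 1))?; // logits for the next token
        let next = last.argmax(D::Minus1)?.to_scalar::<u32>()?;
        if next == sep_token_id {
            break; // stop once the end-of-sequence (separator) token is produced
        }
        tokens.push(next);
    }
    Ok(tokens)
}
```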