mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add the blip example. (#1144)
* Add the blip example. * Tweak the example. * Implement the cross-attn logic. * Fix some shape mismatches. * Get some logits out. * Get some caption to be generated.
This commit is contained in:
@ -5,24 +5,59 @@ use candle::{Module, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct VisionConfig {
|
||||
hidden_size: usize,
|
||||
intermediate_size: usize,
|
||||
projection_dim: usize,
|
||||
num_hidden_layers: usize,
|
||||
num_attention_heads: usize,
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
hidden_act: candle_nn::Activation,
|
||||
layer_norm_eps: f64,
|
||||
pub struct VisionConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub projection_dim: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub image_size: usize,
|
||||
pub patch_size: usize,
|
||||
pub hidden_act: candle_nn::Activation,
|
||||
pub layer_norm_eps: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Config {
|
||||
text_config: blip_text::Config,
|
||||
vision_config: VisionConfig,
|
||||
projection_dim: usize,
|
||||
image_text_hidden_size: usize,
|
||||
pub struct Config {
|
||||
pub text_config: blip_text::Config,
|
||||
pub vision_config: VisionConfig,
|
||||
pub projection_dim: usize,
|
||||
pub image_text_hidden_size: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn image_captioning_large() -> Self {
|
||||
let text_config = blip_text::Config {
|
||||
vocab_size: 30524,
|
||||
hidden_size: 768,
|
||||
encoder_hidden_size: 1024,
|
||||
intermediate_size: 3072,
|
||||
projection_dim: 768,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
max_position_embeddings: 512,
|
||||
hidden_act: candle_nn::Activation::Gelu,
|
||||
layer_norm_eps: 1e-12,
|
||||
is_decoder: true,
|
||||
};
|
||||
let vision_config = VisionConfig {
|
||||
hidden_size: 1024,
|
||||
intermediate_size: 4096,
|
||||
projection_dim: 512,
|
||||
num_hidden_layers: 24,
|
||||
num_attention_heads: 16,
|
||||
image_size: 384,
|
||||
patch_size: 16,
|
||||
hidden_act: candle_nn::Activation::Gelu,
|
||||
layer_norm_eps: 1e-5,
|
||||
};
|
||||
Self {
|
||||
text_config,
|
||||
vision_config,
|
||||
projection_dim: 512,
|
||||
image_text_hidden_size: 256,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -200,6 +235,7 @@ struct Encoder {
|
||||
impl Encoder {
|
||||
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb = vb.pp("layers");
|
||||
for i in 0..cfg.num_hidden_layers {
|
||||
let layer = EncoderLayer::new(cfg, vb.pp(i))?;
|
||||
layers.push(layer)
|
||||
@ -217,7 +253,7 @@ impl Encoder {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct VisionModel {
|
||||
pub struct VisionModel {
|
||||
embeddings: VisionEmbeddings,
|
||||
encoder: Encoder,
|
||||
post_layernorm: LayerNorm,
|
||||
@ -241,23 +277,19 @@ impl Module for VisionModel {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.embeddings)?;
|
||||
let encoder_outputs = self.encoder.forward(&xs, None)?;
|
||||
let last_hidden_state = encoder_outputs.get(0)?;
|
||||
last_hidden_state
|
||||
.apply(&self.post_layernorm)?
|
||||
.narrow(1, 0, 1)?
|
||||
.squeeze(1)?
|
||||
.apply(&self.post_layernorm)
|
||||
// Return the last hidden state rather than pooled outputs.
|
||||
encoder_outputs.apply(&self.post_layernorm)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct BlipForConditionalGeneration {
|
||||
pub struct BlipForConditionalGeneration {
|
||||
vision_model: VisionModel,
|
||||
text_decoder: blip_text::TextLMHeadModel,
|
||||
}
|
||||
|
||||
impl BlipForConditionalGeneration {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
|
||||
let text_decoder =
|
||||
blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
|
||||
@ -267,12 +299,38 @@ impl BlipForConditionalGeneration {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
pub fn vision_model(&self) -> &VisionModel {
|
||||
&self.vision_model
|
||||
}
|
||||
|
||||
pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
|
||||
&self.text_decoder
|
||||
}
|
||||
|
||||
pub fn generate(
|
||||
&self,
|
||||
pixel_values: &Tensor,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let image_embeds = pixel_values.apply(&self.vision_model)?;
|
||||
let b_size = image_embeds.dim(0)?;
|
||||
if b_size > 1 {
|
||||
candle::bail!("only a batch size of 1 is supported")
|
||||
}
|
||||
let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
|
||||
let mut token_ids = vec![30522u32];
|
||||
for i in 0..1000 {
|
||||
let input_ids =
|
||||
Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?;
|
||||
let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
|
||||
println!("{logits:?}");
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
println!("{token}");
|
||||
token_ids.push(token)
|
||||
}
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
@ -5,17 +5,17 @@ use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
encoder_hidden_size: usize,
|
||||
intermediate_size: usize,
|
||||
projection_dim: usize,
|
||||
num_hidden_layers: usize,
|
||||
num_attention_heads: usize,
|
||||
max_position_embeddings: usize,
|
||||
hidden_act: candle_nn::Activation,
|
||||
layer_norm_eps: f64,
|
||||
is_decoder: bool,
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub encoder_hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub projection_dim: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub hidden_act: candle_nn::Activation,
|
||||
pub layer_norm_eps: f64,
|
||||
pub is_decoder: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -47,6 +47,17 @@ impl TextEmbeddings {
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TextEmbeddings {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let seq_len = xs.dim(1)?;
|
||||
// Use past_key_values_length if we add a kv cache.
|
||||
let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
|
||||
let embeddings = self.word_embedddings.forward(xs)?;
|
||||
let position_embeddings = self.position_embeddings.forward(&position_ids)?;
|
||||
(embeddings + position_embeddings)?.apply(&self.layer_norm)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TextSelfAttention {
|
||||
query: Linear,
|
||||
@ -55,6 +66,7 @@ struct TextSelfAttention {
|
||||
all_head_size: usize,
|
||||
attention_head_size: usize,
|
||||
num_attention_heads: usize,
|
||||
attention_scale: f64,
|
||||
}
|
||||
|
||||
impl TextSelfAttention {
|
||||
@ -70,6 +82,7 @@ impl TextSelfAttention {
|
||||
};
|
||||
let key = linear(in_size, all_head_size, vb.pp("key"))?;
|
||||
let value = linear(in_size, all_head_size, vb.pp("value"))?;
|
||||
let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
@ -77,6 +90,7 @@ impl TextSelfAttention {
|
||||
all_head_size,
|
||||
attention_head_size,
|
||||
num_attention_heads,
|
||||
attention_scale,
|
||||
})
|
||||
}
|
||||
|
||||
@ -90,6 +104,35 @@ impl TextSelfAttention {
|
||||
))?
|
||||
.permute((0, 2, 1, 3))
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
|
||||
let query = self
|
||||
.transpose_for_scores(&self.query.forward(xs)?)?
|
||||
.contiguous()?;
|
||||
let (key, value) = match encoder_hidden_states {
|
||||
None => {
|
||||
let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
|
||||
let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
|
||||
// TODO: kv cache
|
||||
(key, value)
|
||||
}
|
||||
Some(xs) => {
|
||||
let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
|
||||
let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
|
||||
// no kv-cache in this case, but the results could probably be memoized.
|
||||
(key, value)
|
||||
}
|
||||
};
|
||||
let key = key.contiguous()?;
|
||||
let value = value.contiguous()?;
|
||||
let attention_scores = query.matmul(&key.t()?)?;
|
||||
let attention_scores = (attention_scores * self.attention_scale)?;
|
||||
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
|
||||
attention_probs
|
||||
.matmul(&value)?
|
||||
.permute((0, 2, 1, 3))?
|
||||
.flatten_from(D::Minus2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -122,6 +165,11 @@ impl TextAttention {
|
||||
let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
|
||||
Ok(Self { self_, output })
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, encoder_hidden_states: Option<&Tensor>) -> Result<Tensor> {
|
||||
let self_outputs = self.self_.forward(xs, encoder_hidden_states)?;
|
||||
self.output.forward(&self_outputs, xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -176,7 +224,7 @@ impl TextLayer {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
|
||||
let cross_attention = if cfg.is_decoder {
|
||||
Some(TextAttention::new(cfg, true, vb.pp("attention"))?)
|
||||
Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -189,11 +237,15 @@ impl TextLayer {
|
||||
output,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TextLayer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let attention_output = self.attention.forward(xs, None)?;
|
||||
let attention_output = match &self.cross_attention {
|
||||
Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states))?,
|
||||
None => candle::bail!("expected some cross-attn"),
|
||||
};
|
||||
let intermediate_output = self.intermediate.forward(&attention_output)?;
|
||||
self.output.forward(&intermediate_output, &attention_output)
|
||||
}
|
||||
}
|
||||
|
||||
@ -212,13 +264,11 @@ impl TextEncoder {
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TextEncoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
xs = xs.apply(layer)?
|
||||
xs = layer.forward(&xs, encoder_hidden_states)?
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
@ -333,6 +383,15 @@ impl TextModel {
|
||||
pooler: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let embedding_output = self.embeddings.forward(input_ids)?;
|
||||
let sequence_output = self
|
||||
.encoder
|
||||
.forward(&embedding_output, encoder_hidden_states)?;
|
||||
// We're interested in the sequence-output rather than the pooled-output.
|
||||
Ok(sequence_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -347,4 +406,11 @@ impl TextLMHeadModel {
|
||||
let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
|
||||
Ok(Self { bert, cls })
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let sequence_output = self.bert.forward(input_ids, encoder_hidden_states)?;
|
||||
let prediction_scores = self.cls.forward(&sequence_output)?;
|
||||
// return_logits is false so we don't discard the last sequence element.
|
||||
Ok(prediction_scores)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user