Add the blip example. (#1144)
* Add the blip example.
* Tweak the example.
* Implement the cross-attn logic.
* Fix some shape mismatches.
* Get some logits out.
* Get some caption to be generated.
@@ -5,24 +5,59 @@ use candle::{Module, Result, Tensor, D};
 use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
 
 #[derive(Debug, Clone)]
-struct VisionConfig {
-    hidden_size: usize,
-    intermediate_size: usize,
-    projection_dim: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    image_size: usize,
-    patch_size: usize,
-    hidden_act: candle_nn::Activation,
-    layer_norm_eps: f64,
+pub struct VisionConfig {
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub projection_dim: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub image_size: usize,
+    pub patch_size: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
 }
 
 #[derive(Debug, Clone)]
-struct Config {
-    text_config: blip_text::Config,
-    vision_config: VisionConfig,
-    projection_dim: usize,
-    image_text_hidden_size: usize,
+pub struct Config {
+    pub text_config: blip_text::Config,
+    pub vision_config: VisionConfig,
+    pub projection_dim: usize,
+    pub image_text_hidden_size: usize,
 }
 
+impl Config {
+    pub fn image_captioning_large() -> Self {
+        let text_config = blip_text::Config {
+            vocab_size: 30524,
+            hidden_size: 768,
+            encoder_hidden_size: 1024,
+            intermediate_size: 3072,
+            projection_dim: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            max_position_embeddings: 512,
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-12,
+            is_decoder: true,
+        };
+        let vision_config = VisionConfig {
+            hidden_size: 1024,
+            intermediate_size: 4096,
+            projection_dim: 512,
+            num_hidden_layers: 24,
+            num_attention_heads: 16,
+            image_size: 384,
+            patch_size: 16,
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-5,
+        };
+        Self {
+            text_config,
+            vision_config,
+            projection_dim: 512,
+            image_text_hidden_size: 256,
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
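With the config structs and their fields now public, a downstream crate can build the large-captioning configuration directly. A minimal sketch, assuming the module is exposed as candle_transformers::models::blip (the import path and binary setup are illustrative, not part of this commit):

    use candle_transformers::models::blip;

    fn main() {
        let cfg = blip::Config::image_captioning_large();
        // 384x384 input split into 16x16 patches: (384 / 16)^2 = 576 patches.
        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        println!(
            "vision hidden size: {}, layers: {}, patches per image: {}",
            cfg.vision_config.hidden_size, cfg.vision_config.num_hidden_layers, num_patches
        );
    }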
@@ -200,6 +235,7 @@ struct Encoder {
 impl Encoder {
     fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb = vb.pp("layers");
         for i in 0..cfg.num_hidden_layers {
            let layer = EncoderLayer::new(cfg, vb.pp(i))?;
            layers.push(layer)
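The single added line scopes the VarBuilder under a "layers" prefix before indexing, so layer i reads its weights from names like "layers.0.…" rather than "0.…". A small sketch of that prefixing behaviour, assuming candle_nn's VarBuilder/VarMap API and using a fresh VarMap purely to show how the names nest (not how the model is actually loaded):

    use candle_core::{DType, Device, Result};
    use candle_nn::{VarBuilder, VarMap};

    fn main() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        // As in Encoder::new above: push the "layers" prefix once, then the index,
        // so a tensor requested as "weight" resolves under "layers.0.weight".
        let vb_layer0 = vb.pp("layers").pp(0);
        let _w = vb_layer0.get((4, 4), "weight")?;
        for (name, _var) in varmap.data().lock().unwrap().iter() {
            println!("{name}"); // expected: layers.0.weight
        }
        Ok(())
    }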
@@ -217,7 +253,7 @@ impl Encoder {
 }
 
 #[derive(Debug, Clone)]
-struct VisionModel {
+pub struct VisionModel {
     embeddings: VisionEmbeddings,
     encoder: Encoder,
     post_layernorm: LayerNorm,
@@ -241,23 +277,19 @@ impl Module for VisionModel {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let xs = xs.apply(&self.embeddings)?;
         let encoder_outputs = self.encoder.forward(&xs, None)?;
-        let last_hidden_state = encoder_outputs.get(0)?;
-        last_hidden_state
-            .apply(&self.post_layernorm)?
-            .narrow(1, 0, 1)?
-            .squeeze(1)?
-            .apply(&self.post_layernorm)
+        // Return the last hidden state rather than pooled outputs.
+        encoder_outputs.apply(&self.post_layernorm)
     }
 }
 
 #[derive(Debug, Clone)]
-struct BlipForConditionalGeneration {
+pub struct BlipForConditionalGeneration {
     vision_model: VisionModel,
     text_decoder: blip_text::TextLMHeadModel,
 }
 
 impl BlipForConditionalGeneration {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
         let text_decoder =
             blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
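The forward change above is what the new cross-attention needs: the text decoder attends over every image token, so the vision model now returns the full, layer-normed last hidden state instead of a pooled CLS vector. A back-of-the-envelope shape check using the image_captioning_large numbers above, assuming the usual ViT class token (plain arithmetic only; it does not run the model):

    fn main() {
        let (image_size, patch_size, hidden_size) = (384usize, 16usize, 1024usize);
        let num_patches = (image_size / patch_size) * (image_size / patch_size); // 576
        let seq_len = num_patches + 1; // + 1 for the class embedding token
        // Before this change: pooled output of shape (batch, hidden_size) = (1, 1024).
        // After this change: last hidden state of shape (batch, seq_len, hidden_size) = (1, 577, 1024),
        // which is what the text decoder cross-attends over.
        println!("image_embeds: (1, {seq_len}, {hidden_size})");
    }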
@@ -267,12 +299,38 @@ impl BlipForConditionalGeneration {
         })
     }
 
-    fn forward(
+    pub fn vision_model(&self) -> &VisionModel {
+        &self.vision_model
+    }
+
+    pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
+        &self.text_decoder
+    }
+
+    pub fn generate(
         &self,
         pixel_values: &Tensor,
         input_ids: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let image_embeds = pixel_values.apply(&self.vision_model)?;
+        let b_size = image_embeds.dim(0)?;
+        if b_size > 1 {
+            candle::bail!("only a batch size of 1 is supported")
+        }
+        let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
+        let mut token_ids = vec![30522u32];
+        for i in 0..1000 {
+            let input_ids =
+                Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?;
+            let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
+            println!("{logits:?}");
+            let logits = logits.squeeze(0)?;
+            let logits = logits.get(logits.dim(0)? - 1)?;
+            let token = logits_processor.sample(&logits)?;
+            println!("{token}");
+            token_ids.push(token)
+        }
         todo!()
     }
 }
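As the trailing todo!() shows, generate is still a work in progress at this stage: it samples up to 1000 tokens with a LogitsProcessor, prints intermediate logits and token ids, and does not yet return the decoded sequence, which matches the "get some caption to be generated" step in the commit message. A hedged sketch of how a caller might eventually drive it; the safetensors file name, the candle_transformers::models::blip path, and the use of anyhow are assumptions for illustration, and as written the call will still hit the todo!():

    use candle_core::{DType, Device, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::blip;

    fn main() -> anyhow::Result<()> {
        let device = Device::Cpu;
        let cfg = blip::Config::image_captioning_large();
        // Hypothetical weight file; obtaining/converting it is outside this commit.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(
                &["blip-image-captioning-large.safetensors"],
                DType::F32,
                &device,
            )?
        };
        let model = blip::BlipForConditionalGeneration::new(&cfg, vb)?;
        // A single 384x384 RGB image, already resized and normalized, batch dimension first.
        let pixel_values = Tensor::zeros((1, 3, 384, 384), DType::F32, &device)?;
        // input_ids and attention_mask are accepted but unused by this draft of generate.
        let _caption_ids = model.generate(&pixel_values, None, None)?;
        Ok(())
    }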