mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add the blip image captioning model (#1140)
* Blip text model. * Blip vision bits. * Blippity. * More blip.
This commit is contained in:
241
candle-transformers/src/models/blip.rs
Normal file
241
candle-transformers/src/models/blip.rs
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
use super::blip_text;
|
||||||
|
use super::with_tracing::{conv2d, linear, Conv2d, Linear};
|
||||||
|
use candle::{Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct VisionConfig {
|
||||||
|
hidden_size: usize,
|
||||||
|
intermediate_size: usize,
|
||||||
|
projection_dim: usize,
|
||||||
|
num_hidden_layers: usize,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
image_size: usize,
|
||||||
|
patch_size: usize,
|
||||||
|
hidden_act: candle_nn::Activation,
|
||||||
|
layer_norm_eps: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Config {
|
||||||
|
text_config: blip_text::Config,
|
||||||
|
vision_config: VisionConfig,
|
||||||
|
projection_dim: usize,
|
||||||
|
image_text_hidden_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct VisionEmbeddings {
|
||||||
|
class_embedding: Tensor,
|
||||||
|
patch_embedding: Conv2d,
|
||||||
|
position_embedding: Tensor,
|
||||||
|
num_positions: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionEmbeddings {
|
||||||
|
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let class_embedding = vb.get((1, 1, cfg.hidden_size), "class_embedding")?;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
stride: cfg.patch_size,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let patch_embedding = conv2d(
|
||||||
|
3,
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.patch_size,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("patch_embedding"),
|
||||||
|
)?;
|
||||||
|
let num_patches1 = cfg.image_size / cfg.patch_size;
|
||||||
|
let num_patches = num_patches1 * num_patches1;
|
||||||
|
let num_positions = num_patches + 1;
|
||||||
|
let position_embedding =
|
||||||
|
vb.get((1, num_positions, cfg.hidden_size), "position_embedding")?;
|
||||||
|
Ok(Self {
|
||||||
|
class_embedding,
|
||||||
|
patch_embedding,
|
||||||
|
position_embedding,
|
||||||
|
num_positions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for VisionEmbeddings {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let target_dtype = xs.dtype();
|
||||||
|
let b_size = xs.dim(0)?;
|
||||||
|
let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
|
||||||
|
let d = self.class_embedding.dim(D::Minus1)?;
|
||||||
|
let class_embeds = self
|
||||||
|
.class_embedding
|
||||||
|
.broadcast_as((b_size, 1, d))?
|
||||||
|
.to_dtype(target_dtype)?;
|
||||||
|
let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
|
||||||
|
let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
|
||||||
|
embeddings.broadcast_add(&position_embedding)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Attention {
|
||||||
|
qkv: Linear,
|
||||||
|
projection: Linear,
|
||||||
|
scale: f64,
|
||||||
|
embed_dim: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embed_dim = cfg.hidden_size;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let head_dim = embed_dim / num_heads;
|
||||||
|
let scale = 1f64 / (head_dim as f64).sqrt();
|
||||||
|
let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
|
||||||
|
let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
|
||||||
|
Ok(Self {
|
||||||
|
qkv,
|
||||||
|
projection,
|
||||||
|
scale,
|
||||||
|
embed_dim,
|
||||||
|
head_dim,
|
||||||
|
num_heads,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Attention {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
|
||||||
|
let mixed_qkv = xs
|
||||||
|
.apply(&self.qkv)?
|
||||||
|
.reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
|
||||||
|
.permute((2, 0, 3, 1, 4))?;
|
||||||
|
let query = mixed_qkv.get(0)?;
|
||||||
|
let key = mixed_qkv.get(1)?;
|
||||||
|
let value = mixed_qkv.get(2)?;
|
||||||
|
let attention_scores = query.matmul(&key.t()?)?;
|
||||||
|
let attention_scores = (attention_scores * self.scale)?;
|
||||||
|
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
|
||||||
|
attention_probs
|
||||||
|
.matmul(&value)?
|
||||||
|
.permute((0, 2, 1, 3))?
|
||||||
|
.flatten_from(D::Minus2)?
|
||||||
|
.apply(&self.projection)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MLP {
|
||||||
|
activation_fn: candle_nn::Activation,
|
||||||
|
fc1: Linear,
|
||||||
|
fc2: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MLP {
|
||||||
|
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
|
||||||
|
let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
|
||||||
|
Ok(Self {
|
||||||
|
activation_fn: cfg.hidden_act,
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MLP {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.fc1)?
|
||||||
|
.apply(&self.activation_fn)?
|
||||||
|
.apply(&self.fc2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct EncoderLayer {
|
||||||
|
self_attn: Attention,
|
||||||
|
layer_norm1: LayerNorm,
|
||||||
|
mlp: MLP,
|
||||||
|
layer_norm2: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderLayer {
|
||||||
|
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embed_dim = cfg.hidden_size;
|
||||||
|
let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
|
||||||
|
let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
|
||||||
|
let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
|
||||||
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
layer_norm1,
|
||||||
|
mlp,
|
||||||
|
layer_norm2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, attention_mask: Tensor) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = xs.apply(&self.layer_norm1)?;
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Encoder {
|
||||||
|
layers: Vec<EncoderLayer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encoder {
|
||||||
|
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
let layer = EncoderLayer::new(cfg, vb.pp(i))?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
Ok(Self { layers })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct VisionModel {
|
||||||
|
embeddings: VisionEmbeddings,
|
||||||
|
encoder: Encoder,
|
||||||
|
post_layernorm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VisionModel {
|
||||||
|
fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
|
||||||
|
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
|
||||||
|
let post_layernorm =
|
||||||
|
layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
encoder,
|
||||||
|
post_layernorm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct BlipForConditionalGeneration {
|
||||||
|
vision_model: VisionModel,
|
||||||
|
text_decoder: blip_text::TextLMHeadModel,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BlipForConditionalGeneration {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
|
||||||
|
let text_decoder =
|
||||||
|
blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
|
||||||
|
Ok(Self {
|
||||||
|
vision_model,
|
||||||
|
text_decoder,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
350
candle-transformers/src/models/blip_text.rs
Normal file
350
candle-transformers/src/models/blip_text.rs
Normal file
@ -0,0 +1,350 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
|
||||||
|
use candle::{Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
vocab_size: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
encoder_hidden_size: usize,
|
||||||
|
intermediate_size: usize,
|
||||||
|
projection_dim: usize,
|
||||||
|
num_hidden_layers: usize,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
max_position_embeddings: usize,
|
||||||
|
hidden_act: candle_nn::Activation,
|
||||||
|
layer_norm_eps: f64,
|
||||||
|
is_decoder: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextEmbeddings {
|
||||||
|
word_embedddings: Embedding,
|
||||||
|
position_embeddings: Embedding,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
position_ids: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextEmbeddings {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let word_embedddings =
|
||||||
|
Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
|
||||||
|
let position_embeddings = Embedding::new(
|
||||||
|
cfg.max_position_embeddings,
|
||||||
|
cfg.hidden_size,
|
||||||
|
vb.pp("position_embeddings"),
|
||||||
|
)?;
|
||||||
|
let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
|
||||||
|
let position_ids =
|
||||||
|
Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
|
||||||
|
Ok(Self {
|
||||||
|
word_embedddings,
|
||||||
|
position_embeddings,
|
||||||
|
layer_norm,
|
||||||
|
position_ids,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextSelfAttention {
|
||||||
|
query: Linear,
|
||||||
|
key: Linear,
|
||||||
|
value: Linear,
|
||||||
|
all_head_size: usize,
|
||||||
|
attention_head_size: usize,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextSelfAttention {
|
||||||
|
fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let num_attention_heads = cfg.num_attention_heads;
|
||||||
|
let attention_head_size = cfg.hidden_size / num_attention_heads;
|
||||||
|
let all_head_size = cfg.num_attention_heads * attention_head_size;
|
||||||
|
let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
|
||||||
|
let in_size = if is_cross_attention {
|
||||||
|
cfg.encoder_hidden_size
|
||||||
|
} else {
|
||||||
|
cfg.hidden_size
|
||||||
|
};
|
||||||
|
let key = linear(in_size, all_head_size, vb.pp("key"))?;
|
||||||
|
let value = linear(in_size, all_head_size, vb.pp("value"))?;
|
||||||
|
Ok(Self {
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
all_head_size,
|
||||||
|
attention_head_size,
|
||||||
|
num_attention_heads,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let (b_size, seq_len, _) = xs.dims3()?;
|
||||||
|
xs.reshape((
|
||||||
|
b_size,
|
||||||
|
seq_len,
|
||||||
|
self.num_attention_heads,
|
||||||
|
self.attention_head_size,
|
||||||
|
))?
|
||||||
|
.permute((0, 2, 1, 3))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextSelfOutput {
|
||||||
|
dense: Linear,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextSelfOutput {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
|
||||||
|
let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
|
||||||
|
Ok(Self { dense, layer_norm })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||||
|
(xs.apply(&self.dense) + input_tensor)?.apply(&self.layer_norm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextAttention {
|
||||||
|
self_: TextSelfAttention,
|
||||||
|
output: TextSelfOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextAttention {
|
||||||
|
fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
|
||||||
|
let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
|
||||||
|
Ok(Self { self_, output })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextIntermediate {
|
||||||
|
dense: Linear,
|
||||||
|
intermediate_act_fn: candle_nn::Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextIntermediate {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
|
||||||
|
Ok(Self {
|
||||||
|
dense,
|
||||||
|
intermediate_act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TextIntermediate {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextOutput {
|
||||||
|
dense: Linear,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextOutput {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
|
||||||
|
let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
|
||||||
|
Ok(Self { dense, layer_norm })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||||
|
(xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextLayer {
|
||||||
|
attention: TextAttention,
|
||||||
|
cross_attention: Option<TextAttention>,
|
||||||
|
intermediate: TextIntermediate,
|
||||||
|
output: TextOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextLayer {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
|
||||||
|
let cross_attention = if cfg.is_decoder {
|
||||||
|
Some(TextAttention::new(cfg, true, vb.pp("attention"))?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
|
||||||
|
let output = TextOutput::new(cfg, vb.pp("output"))?;
|
||||||
|
Ok(Self {
|
||||||
|
attention,
|
||||||
|
cross_attention,
|
||||||
|
intermediate,
|
||||||
|
output,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TextLayer {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextEncoder {
|
||||||
|
layers: Vec<TextLayer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextEncoder {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb = vb.pp("layer");
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
let layer = TextLayer::new(cfg, vb.pp(i))?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
Ok(Self { layers })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TextEncoder {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
xs = xs.apply(layer)?
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextPooler {
|
||||||
|
dense: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextPooler {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
|
||||||
|
Ok(Self { dense })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TextPooler {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.narrow(D::Minus1, 0, 1)?
|
||||||
|
.squeeze(D::Minus1)?
|
||||||
|
.apply(&self.dense)?
|
||||||
|
.tanh()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextPredictionHeadTransform {
|
||||||
|
dense: Linear,
|
||||||
|
transform_act_fn: candle_nn::Activation,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextPredictionHeadTransform {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
|
||||||
|
let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
dense,
|
||||||
|
transform_act_fn: cfg.hidden_act,
|
||||||
|
layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TextPredictionHeadTransform {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.dense)?
|
||||||
|
.apply(&self.transform_act_fn)?
|
||||||
|
.apply(&self.layer_norm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextLMPredictionHead {
|
||||||
|
transform: TextPredictionHeadTransform,
|
||||||
|
decoder: Linear,
|
||||||
|
bias: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextLMPredictionHead {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
|
||||||
|
let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
|
||||||
|
let bias = vb.get(cfg.vocab_size, "bias")?;
|
||||||
|
Ok(Self {
|
||||||
|
transform,
|
||||||
|
decoder,
|
||||||
|
bias,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TextLMPredictionHead {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.transform)?.apply(&self.decoder)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextOnlyMLMHead {
|
||||||
|
predictions: TextLMPredictionHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextOnlyMLMHead {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
|
||||||
|
Ok(Self { predictions })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for TextOnlyMLMHead {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
self.predictions.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct TextModel {
|
||||||
|
embeddings: TextEmbeddings,
|
||||||
|
encoder: TextEncoder,
|
||||||
|
pooler: Option<TextPooler>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextModel {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
|
||||||
|
let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
encoder,
|
||||||
|
pooler: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct TextLMHeadModel {
|
||||||
|
bert: TextModel,
|
||||||
|
cls: TextOnlyMLMHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextLMHeadModel {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let bert = TextModel::new(cfg, vb.pp("bert"))?;
|
||||||
|
let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
|
||||||
|
Ok(Self { bert, cls })
|
||||||
|
}
|
||||||
|
}
|
@ -1,5 +1,7 @@
|
|||||||
pub mod bert;
|
pub mod bert;
|
||||||
pub mod bigcode;
|
pub mod bigcode;
|
||||||
|
pub mod blip;
|
||||||
|
pub mod blip_text;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
pub mod dinov2;
|
pub mod dinov2;
|
||||||
pub mod efficientnet;
|
pub mod efficientnet;
|
||||||
|
@ -58,8 +58,8 @@ pub struct Conv2d {
|
|||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Conv2d {
|
impl Module for Conv2d {
|
||||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
self.inner.forward(x)
|
self.inner.forward(x)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user