Remove the unused pragma and properly apply the bias. (#1147)

Laurent Mazare
2023-10-22 06:47:40 +01:00
committed by GitHub
parent 3115fe42e4
commit 5b32c2a41e
3 changed files with 15 additions and 22 deletions


@@ -1,4 +1,3 @@
-#![allow(unused)]
 use super::blip_text;
 use super::with_tracing::{conv2d, linear, Conv2d, Linear};
 use candle::{Module, Result, Tensor, D};
@@ -65,7 +64,6 @@ struct VisionEmbeddings {
     class_embedding: Tensor,
     patch_embedding: Conv2d,
     position_embedding: Tensor,
-    num_positions: usize,
 }
 
 impl VisionEmbeddings {
@@ -91,7 +89,6 @@ impl VisionEmbeddings {
             class_embedding,
             patch_embedding,
             position_embedding,
-            num_positions,
         })
     }
 }
@@ -117,8 +114,6 @@ struct Attention {
     qkv: Linear,
     projection: Linear,
     scale: f64,
-    embed_dim: usize,
-    head_dim: usize,
     num_heads: usize,
 }
 
@@ -134,8 +129,6 @@ impl Attention {
             qkv,
             projection,
             scale,
-            embed_dim,
-            head_dim,
             num_heads,
         })
     }
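
With the crate-level #![allow(unused)] pragma gone, rustc's dead_code lint is what flags fields such as num_positions, embed_dim, and head_dim that were stored but never read, which is why they are deleted outright rather than silenced. A minimal standalone sketch of that lint (illustrative, not part of the commit):

// Compiled on its own, this emits `warning: field `head_dim` is never read`,
// the same class of warning the removed fields triggered once the pragma no
// longer suppressed it.
struct Attention {
    scale: f64,
    head_dim: usize, // stored but never read -> dead_code warning
}

fn main() {
    // Only `scale` is ever used, so `head_dim` is dead weight.
    let attn = Attention {
        scale: 0.125,
        head_dim: 64,
    };
    println!("attention scale = {}", attn.scale);
}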


@@ -1,5 +1,4 @@
-#![allow(unused)]
-use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
+use super::with_tracing::{linear, Embedding, Linear};
 use candle::{Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
@@ -63,7 +62,6 @@ struct TextSelfAttention {
     query: Linear,
     key: Linear,
     value: Linear,
-    all_head_size: usize,
     attention_head_size: usize,
     num_attention_heads: usize,
     attention_scale: f64,
@@ -87,7 +85,6 @@ impl TextSelfAttention {
             query,
             key,
             value,
-            all_head_size,
             attention_head_size,
             num_attention_heads,
             attention_scale,
@@ -301,12 +298,12 @@ impl TextEncoder {
 }
 
 #[derive(Debug, Clone)]
-struct TextPooler {
+pub struct TextPooler {
     dense: Linear,
 }
 
 impl TextPooler {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
         Ok(Self { dense })
     }
@@ -352,19 +349,15 @@ impl Module for TextPredictionHeadTransform {
 struct TextLMPredictionHead {
     transform: TextPredictionHeadTransform,
     decoder: Linear,
-    bias: Tensor,
 }
 
 impl TextLMPredictionHead {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
-        let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
+        let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
         let bias = vb.get(cfg.vocab_size, "bias")?;
-        Ok(Self {
-            transform,
-            decoder,
-            bias,
-        })
+        let decoder = Linear::from_weights(weight, Some(bias));
+        Ok(Self { transform, decoder })
     }
 }
@@ -396,7 +389,7 @@ impl Module for TextOnlyMLMHead {
 struct TextModel {
     embeddings: TextEmbeddings,
     encoder: TextEncoder,
-    pooler: Option<TextPooler>,
+    // We do not need the pooler for caption generation
 }
 
 impl TextModel {
@@ -406,7 +399,6 @@ impl TextModel {
         Ok(Self {
             embeddings,
             encoder,
-            pooler: None,
         })
     }
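
The substantive fix is in TextLMPredictionHead: the old code built the decoder with linear_no_bias and kept bias as a separate struct field, so the tensor was loaded but apparently never added to the logits. Rebuilding the decoder via Linear::from_weights(weight, Some(bias)) folds the bias into the layer itself, and Module::forward then applies it on every call. A small sketch of what the combined layer computes (toy shapes, public candle/candle_nn APIs only):

// The rebuilt decoder computes xs.matmul(weight.t()) + bias in one forward
// pass; the hand-written version below matches it exactly.
use candle::{Device, Module, Result, Tensor};
use candle_nn::Linear;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (vocab_size, hidden_size) = (8, 4); // toy stand-ins for the real config
    let weight = Tensor::randn(0f32, 1.0, (vocab_size, hidden_size), &dev)?;
    let bias = Tensor::randn(0f32, 1.0, vocab_size, &dev)?;
    let decoder = Linear::new(weight.clone(), Some(bias.clone()));

    let xs = Tensor::randn(0f32, 1.0, (2, hidden_size), &dev)?;
    let logits = decoder.forward(&xs)?;
    // Same computation written out by hand; the difference is zero.
    let manual = xs.matmul(&weight.t()?)?.broadcast_add(&bias)?;
    let total_abs_diff = (&logits - &manual)?.abs()?.sum_all()?;
    println!("sum |logits - manual| = {total_abs_diff}");
    Ok(())
}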


@@ -32,6 +32,14 @@ pub struct Linear {
     span: tracing::Span,
 }
 
+impl Linear {
+    pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self {
+        let inner = candle_nn::Linear::new(weights, bias);
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Self { inner, span }
+    }
+}
+
 pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
     let inner = candle_nn::linear(d1, d2, vb)?;
     let span = tracing::span!(tracing::Level::TRACE, "linear");
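
from_weights exists so a caller that already holds raw tensors (here the decoder.weight matrix plus the separately stored bias vector) can still get a span-instrumented Linear instead of dropping tracing. A usage sketch of the pattern from the text-model change above; VarBuilder::zeros and the toy sizes stand in for a real checkpoint, and candle_nn::Linear::new is called directly because with_tracing lives inside the crate (in-tree the call is Linear::from_weights(weight, Some(bias))):

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Zero-initialized stand-in for a VarBuilder backed by real weights.
    let vb = VarBuilder::zeros(DType::F32, &dev);
    let (vocab_size, hidden_size) = (8, 4); // toy sizes
    // Fetch the weight matrix and the separately stored bias by name...
    let weight = vb.get((vocab_size, hidden_size), "decoder.weight")?;
    let bias = vb.get(vocab_size, "bias")?;
    // ...and wrap both in a single layer so forward() applies the bias.
    let decoder = candle_nn::Linear::new(weight, Some(bias));
    let xs = Tensor::zeros((1, hidden_size), DType::F32, &dev)?;
    let logits = decoder.forward(&xs)?;
    println!("logits shape: {:?}", logits.shape());
    Ok(())
}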