Remove the unused pragma and properly apply the bias. (#1147)

Laurent Mazare
2023-10-22 06:47:40 +01:00
committed by GitHub
parent 3115fe42e4
commit 5b32c2a41e
3 changed files with 15 additions and 22 deletions

View File

@@ -1,4 +1,3 @@
-#![allow(unused)]
 use super::blip_text;
 use super::with_tracing::{conv2d, linear, Conv2d, Linear};
 use candle::{Module, Result, Tensor, D};
@@ -65,7 +64,6 @@ struct VisionEmbeddings {
     class_embedding: Tensor,
     patch_embedding: Conv2d,
     position_embedding: Tensor,
-    num_positions: usize,
 }
 
 impl VisionEmbeddings {
@@ -91,7 +89,6 @@ impl VisionEmbeddings {
             class_embedding,
             patch_embedding,
             position_embedding,
-            num_positions,
         })
     }
 }
@@ -117,8 +114,6 @@ struct Attention {
     qkv: Linear,
     projection: Linear,
     scale: f64,
-    embed_dim: usize,
-    head_dim: usize,
     num_heads: usize,
 }
 
@@ -134,8 +129,6 @@ impl Attention {
             qkv,
             projection,
             scale,
-            embed_dim,
-            head_dim,
             num_heads,
         })
     }

View File

@@ -1,5 +1,4 @@
-#![allow(unused)]
-use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
+use super::with_tracing::{linear, Embedding, Linear};
 use candle::{Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
@@ -63,7 +62,6 @@ struct TextSelfAttention {
     query: Linear,
     key: Linear,
     value: Linear,
-    all_head_size: usize,
     attention_head_size: usize,
     num_attention_heads: usize,
     attention_scale: f64,
@@ -87,7 +85,6 @@ impl TextSelfAttention {
             query,
             key,
             value,
-            all_head_size,
             attention_head_size,
             num_attention_heads,
             attention_scale,
@@ -301,12 +298,12 @@ impl TextEncoder {
 }
 
 #[derive(Debug, Clone)]
-struct TextPooler {
+pub struct TextPooler {
     dense: Linear,
 }
 
 impl TextPooler {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
         Ok(Self { dense })
     }
@@ -352,19 +349,15 @@ impl Module for TextPredictionHeadTransform {
 struct TextLMPredictionHead {
     transform: TextPredictionHeadTransform,
     decoder: Linear,
-    bias: Tensor,
 }
 
 impl TextLMPredictionHead {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
-        let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
+        let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
         let bias = vb.get(cfg.vocab_size, "bias")?;
-        Ok(Self {
-            transform,
-            decoder,
-            bias,
-        })
+        let decoder = Linear::from_weights(weight, Some(bias));
+        Ok(Self { transform, decoder })
     }
 }
 
@@ -396,7 +389,7 @@ impl Module for TextOnlyMLMHead {
 struct TextModel {
     embeddings: TextEmbeddings,
     encoder: TextEncoder,
-    pooler: Option<TextPooler>,
+    // We do not need the pooler for caption generation
 }
 
 impl TextModel {
@@ -406,7 +399,6 @@ impl TextModel {
         Ok(Self {
             embeddings,
             encoder,
-            pooler: None,
         })
     }
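In short, the old TextLMPredictionHead built its decoder with linear_no_bias and kept the separately loaded bias tensor in a struct field that, per the commit title, was never actually added to the logits. The rewrite loads decoder.weight and bias explicitly and folds them into a single Linear via the new from_weights constructor, so the bias is applied inside the decoder's forward pass. A minimal sketch of the resulting behaviour, using plain candle_nn::Linear and made-up sizes in place of cfg.hidden_size / cfg.vocab_size:

use candle::{DType, Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (hidden_size, vocab_size) = (4, 10); // illustrative, not the BLIP config
    let weight = Tensor::randn(0f32, 1., (vocab_size, hidden_size), &dev)?;
    let bias = Tensor::zeros(vocab_size, DType::F32, &dev)?;
    // Same construction as Linear::from_weights(weight, Some(bias)) above:
    // forward computes xs times the transposed weight, then adds the bias.
    let decoder = candle_nn::Linear::new(weight, Some(bias));
    let xs = Tensor::randn(0f32, 1., (1, hidden_size), &dev)?;
    let logits = decoder.forward(&xs)?;
    assert_eq!(logits.dims2()?, (1, vocab_size));
    Ok(())
}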

View File

@@ -32,6 +32,14 @@ pub struct Linear {
     span: tracing::Span,
 }
 
+impl Linear {
+    pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self {
+        let inner = candle_nn::Linear::new(weights, bias);
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Self { inner, span }
+    }
+}
+
 pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
     let inner = candle_nn::linear(d1, d2, vb)?;
     let span = tracing::span!(tracing::Level::TRACE, "linear");
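As a usage note: unlike the existing linear / linear_no_bias helpers in this file, which fetch their parameters from a VarBuilder under a single prefix, from_weights accepts tensors that have already been loaded, which is what lets blip_text.rs combine decoder.weight and the top-level bias into one traced Linear. A small illustrative sketch, assuming it sits alongside the code above inside the same module (the helper name is made up):

// Hypothetical helper: build a traced Linear whose weight and bias live
// under different names in the checkpoint.
pub fn linear_from_split_names(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((out_dim, in_dim), "decoder.weight")?;
    let bias = vb.get(out_dim, "bias")?;
    Ok(Linear::from_weights(weight, Some(bias)))
}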