mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Remove the unused pragma and properly apply the bias. (#1147)
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
#![allow(unused)]
|
||||
use super::blip_text;
|
||||
use super::with_tracing::{conv2d, linear, Conv2d, Linear};
|
||||
use candle::{Module, Result, Tensor, D};
|
||||
@ -65,7 +64,6 @@ struct VisionEmbeddings {
|
||||
class_embedding: Tensor,
|
||||
patch_embedding: Conv2d,
|
||||
position_embedding: Tensor,
|
||||
num_positions: usize,
|
||||
}
|
||||
|
||||
impl VisionEmbeddings {
|
||||
@ -91,7 +89,6 @@ impl VisionEmbeddings {
|
||||
class_embedding,
|
||||
patch_embedding,
|
||||
position_embedding,
|
||||
num_positions,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -117,8 +114,6 @@ struct Attention {
|
||||
qkv: Linear,
|
||||
projection: Linear,
|
||||
scale: f64,
|
||||
embed_dim: usize,
|
||||
head_dim: usize,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
@ -134,8 +129,6 @@ impl Attention {
|
||||
qkv,
|
||||
projection,
|
||||
scale,
|
||||
embed_dim,
|
||||
head_dim,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
@ -1,5 +1,4 @@
|
||||
#![allow(unused)]
|
||||
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
|
||||
use super::with_tracing::{linear, Embedding, Linear};
|
||||
use candle::{Module, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
||||
|
||||
@ -63,7 +62,6 @@ struct TextSelfAttention {
|
||||
query: Linear,
|
||||
key: Linear,
|
||||
value: Linear,
|
||||
all_head_size: usize,
|
||||
attention_head_size: usize,
|
||||
num_attention_heads: usize,
|
||||
attention_scale: f64,
|
||||
@ -87,7 +85,6 @@ impl TextSelfAttention {
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
all_head_size,
|
||||
attention_head_size,
|
||||
num_attention_heads,
|
||||
attention_scale,
|
||||
@ -301,12 +298,12 @@ impl TextEncoder {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TextPooler {
|
||||
pub struct TextPooler {
|
||||
dense: Linear,
|
||||
}
|
||||
|
||||
impl TextPooler {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
|
||||
Ok(Self { dense })
|
||||
}
|
||||
@ -352,19 +349,15 @@ impl Module for TextPredictionHeadTransform {
|
||||
struct TextLMPredictionHead {
|
||||
transform: TextPredictionHeadTransform,
|
||||
decoder: Linear,
|
||||
bias: Tensor,
|
||||
}
|
||||
|
||||
impl TextLMPredictionHead {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
|
||||
let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
|
||||
let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
|
||||
let bias = vb.get(cfg.vocab_size, "bias")?;
|
||||
Ok(Self {
|
||||
transform,
|
||||
decoder,
|
||||
bias,
|
||||
})
|
||||
let decoder = Linear::from_weights(weight, Some(bias));
|
||||
Ok(Self { transform, decoder })
|
||||
}
|
||||
}
|
||||
|
||||
@ -396,7 +389,7 @@ impl Module for TextOnlyMLMHead {
|
||||
struct TextModel {
|
||||
embeddings: TextEmbeddings,
|
||||
encoder: TextEncoder,
|
||||
pooler: Option<TextPooler>,
|
||||
// We do not need the pooler for caption generation
|
||||
}
|
||||
|
||||
impl TextModel {
|
||||
@ -406,7 +399,6 @@ impl TextModel {
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
pooler: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -32,6 +32,14 @@ pub struct Linear {
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self {
|
||||
let inner = candle_nn::Linear::new(weights, bias);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Self { inner, span }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let inner = candle_nn::linear(d1, d2, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
|
Reference in New Issue
Block a user