Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Remove the unused pragma and properly apply the bias. (#1147)
candle-transformers/src/models/blip.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use super::blip_text;
 use super::with_tracing::{conv2d, linear, Conv2d, Linear};
 use candle::{Module, Result, Tensor, D};
@@ -65,7 +64,6 @@ struct VisionEmbeddings {
     class_embedding: Tensor,
     patch_embedding: Conv2d,
     position_embedding: Tensor,
-    num_positions: usize,
 }
 
 impl VisionEmbeddings {
@@ -91,7 +89,6 @@ impl VisionEmbeddings {
             class_embedding,
             patch_embedding,
             position_embedding,
-            num_positions,
         })
     }
 }
@@ -117,8 +114,6 @@ struct Attention {
     qkv: Linear,
     projection: Linear,
     scale: f64,
-    embed_dim: usize,
-    head_dim: usize,
     num_heads: usize,
 }
 
@@ -134,8 +129,6 @@ impl Attention {
             qkv,
             projection,
             scale,
-            embed_dim,
-            head_dim,
             num_heads,
         })
     }
candle-transformers/src/models/blip_text.rs
@@ -1,5 +1,4 @@
-#![allow(unused)]
-use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
+use super::with_tracing::{linear, Embedding, Linear};
 use candle::{Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
@@ -63,7 +62,6 @@ struct TextSelfAttention {
     query: Linear,
     key: Linear,
     value: Linear,
-    all_head_size: usize,
     attention_head_size: usize,
     num_attention_heads: usize,
     attention_scale: f64,
@@ -87,7 +85,6 @@ impl TextSelfAttention {
             query,
             key,
             value,
-            all_head_size,
             attention_head_size,
             num_attention_heads,
             attention_scale,
@@ -301,12 +298,12 @@ impl TextEncoder {
 }
 
 #[derive(Debug, Clone)]
-struct TextPooler {
+pub struct TextPooler {
     dense: Linear,
 }
 
 impl TextPooler {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
         Ok(Self { dense })
     }
@@ -352,19 +349,15 @@ impl Module for TextPredictionHeadTransform {
 struct TextLMPredictionHead {
     transform: TextPredictionHeadTransform,
     decoder: Linear,
-    bias: Tensor,
 }
 
 impl TextLMPredictionHead {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
-        let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
+        let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
         let bias = vb.get(cfg.vocab_size, "bias")?;
-        Ok(Self {
-            transform,
-            decoder,
-            bias,
-        })
+        let decoder = Linear::from_weights(weight, Some(bias));
+        Ok(Self { transform, decoder })
     }
 }
 
@@ -396,7 +389,7 @@ impl Module for TextOnlyMLMHead {
 struct TextModel {
     embeddings: TextEmbeddings,
     encoder: TextEncoder,
-    pooler: Option<TextPooler>,
+    // We do not need the pooler for caption generation
 }
 
 impl TextModel {
@@ -406,7 +399,6 @@ impl TextModel {
         Ok(Self {
             embeddings,
            encoder,
-            pooler: None,
         })
     }
 
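The core of the fix is in TextLMPredictionHead: the checkpoint's bias tensor used to be loaded into a separate `bias` field that the forward pass never added, so the LM head emitted logits without it. Building the decoder with `Linear::from_weights(weight, Some(bias))` folds the bias into the layer itself, so the ordinary linear forward applies it. A minimal sketch of the resulting forward path, assuming the head implements `Module` like the surrounding code (the exact impl is not shown in this diff):

// Sketch, not the verbatim file contents: how the bias now reaches the logits.
impl Module for TextLMPredictionHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // dense + activation + layer-norm over the hidden states
        let xs = self.transform.forward(xs)?;
        // The decoder was built from `decoder.weight` and `bias`, so this single
        // call computes xs * weight^T + bias instead of silently dropping the bias.
        self.decoder.forward(&xs)
    }
}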
candle-transformers/src/models/with_tracing.rs
@@ -32,6 +32,14 @@ pub struct Linear {
     span: tracing::Span,
 }
 
+impl Linear {
+    pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self {
+        let inner = candle_nn::Linear::new(weights, bias);
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Self { inner, span }
+    }
+}
+
 pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
     let inner = candle_nn::linear(d1, d2, vb)?;
     let span = tracing::span!(tracing::Level::TRACE, "linear");
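`from_weights` lets callers wrap raw tensors in the traced `Linear` without going through `candle_nn::linear`, which is what the prediction head above needs because its bias is stored under a top-level `bias` key rather than `decoder.bias`. A hedged usage sketch; `vb`, `vocab_size`, `hidden_size` and `xs` are assumed to be in scope:

// Usage sketch with assumed variables (`vb: VarBuilder`, `xs: Tensor`).
let weight = vb.get((vocab_size, hidden_size), "decoder.weight")?;
let bias = vb.get(vocab_size, "bias")?;
let proj = Linear::from_weights(weight, Some(bias));
// Records a TRACE-level "linear" span, then computes xs * weight^T + bias.
let logits = proj.forward(&xs)?;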