Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Change/bert encoder public (#2658)

* change: BertEncoder struct to public
* change: make certain fields in Config struct public
* change: all fields in bert config struct to be public
* change: add clone to bert encoder and others
* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
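A minimal sketch of what the now-public Config fields enable downstream. It relies on the Deserialize derive visible in the diff below; the config.json path and the serde_json dependency are assumptions, and HiddenAct::Relu is used only because that variant appears in the diff.

    use candle_transformers::models::bert::{Config, HiddenAct};

    fn load_config(path: &str) -> Result<Config, Box<dyn std::error::Error>> {
        // Config derives Deserialize (see the diff), so a Hugging Face
        // style config.json parses directly.
        let mut config: Config = serde_json::from_str(&std::fs::read_to_string(path)?)?;
        // Fields are public after this commit, so callers can inspect
        // or override them.
        config.hidden_act = HiddenAct::Relu;
        println!("hidden_size = {}", config.hidden_size);
        Ok(config)
    }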
@@ -22,6 +22,7 @@ pub enum HiddenAct {
     Relu,
 }
 
+#[derive(Clone)]
 struct HiddenActLayer {
     act: HiddenAct,
     span: tracing::Span,
@@ -46,7 +47,7 @@ impl HiddenActLayer {
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
     #[default]
     Absolute,
 }
@@ -54,24 +55,24 @@ enum PositionEmbeddingType {
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    vocab_size: usize,
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
     pub hidden_act: HiddenAct,
-    hidden_dropout_prob: f64,
-    max_position_embeddings: usize,
-    type_vocab_size: usize,
-    initializer_range: f64,
-    layer_norm_eps: f64,
-    pad_token_id: usize,
+    pub hidden_dropout_prob: f64,
+    pub max_position_embeddings: usize,
+    pub type_vocab_size: usize,
+    pub initializer_range: f64,
+    pub layer_norm_eps: f64,
+    pub pad_token_id: usize,
     #[serde(default)]
-    position_embedding_type: PositionEmbeddingType,
+    pub position_embedding_type: PositionEmbeddingType,
     #[serde(default)]
-    use_cache: bool,
-    classifier_dropout: Option<f64>,
-    model_type: Option<String>,
+    pub use_cache: bool,
+    pub classifier_dropout: Option<f64>,
+    pub model_type: Option<String>,
 }
 
 impl Default for Config {
@@ -121,6 +122,7 @@ impl Config {
     }
 }
 
+#[derive(Clone)]
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
@@ -199,6 +201,7 @@ impl BertEmbeddings {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfAttention {
     query: Linear,
     key: Linear,
@@ -266,6 +269,7 @@ impl BertSelfAttention {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -299,6 +303,7 @@ impl BertSelfOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+#[derive(Clone)]
 struct BertAttention {
     self_attention: BertSelfAttention,
     self_output: BertSelfOutput,
@@ -325,6 +330,7 @@ impl BertAttention {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+#[derive(Clone)]
 struct BertIntermediate {
     dense: Linear,
     intermediate_act: HiddenActLayer,
@@ -352,6 +358,7 @@ impl Module for BertIntermediate {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+#[derive(Clone)]
 struct BertOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -385,7 +392,8 @@ impl BertOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
+#[derive(Clone)]
+pub struct BertLayer {
     attention: BertAttention,
     intermediate: BertIntermediate,
     output: BertOutput,
@@ -420,13 +428,14 @@ impl BertLayer {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
-    layers: Vec<BertLayer>,
+#[derive(Clone)]
+pub struct BertEncoder {
+    pub layers: Vec<BertLayer>,
     span: tracing::Span,
 }
 
 impl BertEncoder {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
             .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
             .collect::<Result<Vec<_>>>()?;
@@ -434,7 +443,7 @@ impl BertEncoder {
         Ok(BertEncoder { layers, span })
     }
 
-    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+    pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
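And a sketch of the headline change: with BertEncoder, its layers field, and load/forward now public, a downstream crate can reuse the encoder stack with its own embeddings. The load/forward signatures match the diff above; the "encoder" weight prefix, the mask shape, and the all-zeros additive mask (meaning "attend everywhere") are assumptions about the surrounding model code.

    use candle_core::{DType, Result, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::bert::{BertEncoder, Config};

    fn encode(vb: VarBuilder, config: &Config, embeddings: &Tensor) -> Result<Tensor> {
        // BertEncoder::load is public after this commit; the "encoder"
        // prefix mirrors the usual BERT weight layout but is an assumption.
        let encoder = BertEncoder::load(vb.pp("encoder"), config)?;
        let (batch, seq_len, _hidden) = embeddings.dims3()?;
        // The mask is added to the attention scores, so zeros leave every
        // position visible; padding would use large negative values instead.
        let mask = Tensor::zeros((batch, 1, 1, seq_len), DType::F32, embeddings.device())?;
        encoder.forward(embeddings, &mask)
    }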