Change/bert encoder public (#2658)

* change: make the BertEncoder struct public

* change: make certain fields in Config struct public

* change: make all fields in the bert Config struct public

* change: derive Clone for the bert encoder and related structs

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Author: Justin Sing
Date: 2024-12-04 15:22:30 -05:00
Committed by: GitHub
Parent: 145aa7193c
Commit: 1807be84f4

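Taken together, these visibility changes let downstream code drive the encoder directly instead of going through BertModel. A minimal sketch of the kind of usage this unlocks, assuming a zeroed VarBuilder, placeholder tensor shapes, and a broadcastable additive attention mask (none of which are part of this commit):

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertEncoder, Config};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // All Config fields are public now, so a default config can be tweaked in place.
    let mut config = Config::default();
    config.num_hidden_layers = 2;

    // Placeholder weights for illustration; a real model would load safetensors
    // into the VarBuilder instead of using zeros.
    let vb = VarBuilder::zeros(DType::F32, &device);

    // BertEncoder::load and BertEncoder::forward are public after this change.
    let encoder = BertEncoder::load(vb.pp("encoder"), &config)?;
    let cloned = encoder.clone(); // possible thanks to the new #[derive(Clone)]

    // Dummy inputs: (batch, seq_len, hidden_size) hidden states and an additive
    // attention mask shaped to broadcast over the attention scores (an assumption
    // about the expected mask layout, not spelled out in this diff).
    let hidden_states = Tensor::zeros((1, 8, config.hidden_size), DType::F32, &device)?;
    let attention_mask = Tensor::zeros((1, 1, 1, 8), DType::F32, &device)?;

    let out = encoder.forward(&hidden_states, &attention_mask)?;
    let _ = cloned.forward(&hidden_states, &attention_mask)?;
    println!("encoder output shape: {:?}", out.dims());
    Ok(())
}

The diff below shows the changes that make this possible.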

@@ -22,6 +22,7 @@ pub enum HiddenAct {
     Relu,
 }
 
+#[derive(Clone)]
 struct HiddenActLayer {
     act: HiddenAct,
     span: tracing::Span,
@@ -46,7 +47,7 @@ impl HiddenActLayer {
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
     #[default]
     Absolute,
 }
@@ -54,24 +55,24 @@ enum PositionEmbeddingType {
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    vocab_size: usize,
-    hidden_size: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    intermediate_size: usize,
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub intermediate_size: usize,
     pub hidden_act: HiddenAct,
-    hidden_dropout_prob: f64,
-    max_position_embeddings: usize,
-    type_vocab_size: usize,
-    initializer_range: f64,
-    layer_norm_eps: f64,
-    pad_token_id: usize,
+    pub hidden_dropout_prob: f64,
+    pub max_position_embeddings: usize,
+    pub type_vocab_size: usize,
+    pub initializer_range: f64,
+    pub layer_norm_eps: f64,
+    pub pad_token_id: usize,
     #[serde(default)]
-    position_embedding_type: PositionEmbeddingType,
+    pub position_embedding_type: PositionEmbeddingType,
     #[serde(default)]
-    use_cache: bool,
-    classifier_dropout: Option<f64>,
-    model_type: Option<String>,
+    pub use_cache: bool,
+    pub classifier_dropout: Option<f64>,
+    pub model_type: Option<String>,
 }
 
 impl Default for Config {
@@ -121,6 +122,7 @@ impl Config {
     }
 }
 
+#[derive(Clone)]
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
@@ -199,6 +201,7 @@ impl BertEmbeddings {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfAttention {
     query: Linear,
     key: Linear,
@@ -266,6 +269,7 @@ impl BertSelfAttention {
     }
 }
 
+#[derive(Clone)]
 struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -299,6 +303,7 @@ impl BertSelfOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+#[derive(Clone)]
 struct BertAttention {
     self_attention: BertSelfAttention,
     self_output: BertSelfOutput,
@@ -325,6 +330,7 @@ impl BertAttention {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+#[derive(Clone)]
 struct BertIntermediate {
     dense: Linear,
     intermediate_act: HiddenActLayer,
@@ -352,6 +358,7 @@ impl Module for BertIntermediate {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+#[derive(Clone)]
 struct BertOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -385,7 +392,8 @@ impl BertOutput {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
+#[derive(Clone)]
+pub struct BertLayer {
     attention: BertAttention,
     intermediate: BertIntermediate,
     output: BertOutput,
@@ -420,13 +428,14 @@ impl BertLayer {
 }
 
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
-    layers: Vec<BertLayer>,
+#[derive(Clone)]
+pub struct BertEncoder {
+    pub layers: Vec<BertLayer>,
     span: tracing::Span,
 }
 
 impl BertEncoder {
-    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
             .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
             .collect::<Result<Vec<_>>>()?;
@@ -434,7 +443,7 @@ impl BertEncoder {
         Ok(BertEncoder { layers, span })
     }
 
-    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+    pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...