From 2309c5fac5028c15738fff47e514593d382db906 Mon Sep 17 00:00:00 2001
From: laurent <laurent.mazare@gmail.com>
Date: Mon, 3 Jul 2023 12:17:06 +0100
Subject: [PATCH] Boilerplate code for Bert.

---
 candle-examples/examples/bert/main.rs | 239 ++++++++++++++++++++++++++
 1 file changed, 239 insertions(+)
 create mode 100644 candle-examples/examples/bert/main.rs

diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
new file mode 100644
index 00000000..68864a7f
--- /dev/null
+++ b/candle-examples/examples/bert/main.rs
@@ -0,0 +1,239 @@
+#![allow(dead_code)]
+
+use anyhow::Result as R;
+use candle::{Result, Tensor};
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum HiddenAct {
+    Gelu,
+    Relu,
+}
+
+impl HiddenAct {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Gelu => xs.gelu(),
+            Self::Relu => xs.relu(),
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum PositionEmbeddingType {
+    Absolute,
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
+#[derive(Debug, Clone, PartialEq)]
+struct Config {
+    vocab_size: usize,
+    hidden_size: usize,
+    num_hidden_layers: usize,
+    num_attention_heads: usize,
+    intermediate_size: usize,
+    hidden_act: HiddenAct,
+    hidden_dropout_prob: f64,
+    max_position_embeddings: usize,
+    type_vocab_size: usize,
+    initializer_range: f64,
+    layer_norm_eps: f64,
+    pad_token_id: usize,
+    position_embedding_type: PositionEmbeddingType,
+    use_cache: bool,
+    classifier_dropout: Option<f64>,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            vocab_size: 30522,
+            hidden_size: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            intermediate_size: 3072,
+            hidden_act: HiddenAct::Gelu,
+            hidden_dropout_prob: 0.1,
+            max_position_embeddings: 512,
+            type_vocab_size: 2,
+            initializer_range: 0.02,
+            layer_norm_eps: 1e-12,
+            pad_token_id: 0,
+            position_embedding_type: PositionEmbeddingType::Absolute,
+            use_cache: true,
+            classifier_dropout: None,
+        }
+    }
+}
+
+struct Embedding {
+    embeddings: Tensor,
+}
+
+impl Embedding {
+    fn new(embeddings: Tensor) -> Self {
+        Self { embeddings }
+    }
+
+    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
+        Tensor::embedding(indexes, &self.embeddings)
+    }
+}
+
+struct Linear {
+    weight: Tensor,
+}
+
+impl Linear {
+    fn new(weight: Tensor) -> Self {
+        Self { weight }
+    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = x.matmul(&self.weight.t()?)?;
+        Ok(x)
+    }
+}
+
+struct Dropout {}
+
+impl Dropout {
+    fn new() -> Self {
+        Self {}
+    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // TODO
+        Ok(x.clone())
+    }
+}
+
+struct LayerNorm {
+    scale: Tensor,
+}
+
+impl LayerNorm {
+    fn new(scale: Tensor) -> Self {
+        Self { scale }
+    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let (seq_len, hidden_size) = x.shape().r2()?;
+        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let size = self.scale.shape().r1()?;
+        let scale = self.scale.broadcast_as((seq_len, size))?;
+        let x = (scale * x_normed)?;
+        Ok(x)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
+struct BertEmbeddings {
+    word_embeddings: Embedding,
+    position_embeddings: Embedding,
+    token_type_embeddings: Embedding,
+    position_ids: Tensor,
+    token_type_ids: Tensor,
+}
+
+struct BertSelfAttention {
+    query: Linear,
+    key: Linear,
+    value: Linear,
+    dropout: Dropout,
+}
+
+struct BertSelfOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
+}
+
+impl BertSelfOutput {
+    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+struct BertAttention {
+    self_attention: BertSelfAttention,
+    self_output: BertSelfOutput,
+}
+
+impl BertAttention {
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+struct BertIntermediate {
+    dense: Linear,
+    intermediate_act: HiddenAct,
+}
+
+impl BertIntermediate {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.dense.forward(hidden_states)?;
+        self.intermediate_act.forward(&hidden_states)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+struct BertOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
+}
+
+impl BertOutput {
+    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.dropout.forward(&hidden_states)?;
+        self.layer_norm.forward(&(hidden_states + input_tensor)?)
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
+struct BertLayer {
+    attention: BertAttention,
+    intermediate: BertIntermediate,
+    output: BertOutput,
+}
+
+impl BertLayer {
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
+struct BertEncoder {
+    layers: Vec<BertLayer>,
+}
+
+impl BertEncoder {
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
+struct BertModel {
+    embeddings: BertEmbeddings,
+    encoder: BertEncoder,
+}
+
+impl BertModel {
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
+fn main() -> R<()> {
+    Ok(())
+}