#![allow(dead_code)]
use anyhow::Result as R;
use candle::{Result, Tensor};

#[derive(Debug, Clone, PartialEq, Eq)]
enum HiddenAct {
    Gelu,
    Relu,
}

impl HiddenAct {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Gelu => xs.gelu(),
            Self::Relu => xs.relu(),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum PositionEmbeddingType {
    Absolute,
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq)]
struct Config {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    hidden_act: HiddenAct,
    hidden_dropout_prob: f64,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    initializer_range: f64,
    layer_norm_eps: f64,
    pad_token_id: usize,
    position_embedding_type: PositionEmbeddingType,
    use_cache: bool,
    classifier_dropout: Option<f64>,
}

impl Default for Config {
    fn default() -> Self {
        // These defaults match the bert-base-uncased configuration.
        Self {
            vocab_size: 30522,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: HiddenAct::Gelu,
            hidden_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Absolute,
            use_cache: true,
            classifier_dropout: None,
        }
    }
}

struct Embedding {
    embeddings: Tensor,
}

impl Embedding {
    fn new(embeddings: Tensor) -> Self {
        Self { embeddings }
    }

    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
        Tensor::embedding(indexes, &self.embeddings)
    }
}

struct Linear {
    weight: Tensor,
}

impl Linear {
    fn new(weight: Tensor) -> Self {
        Self { weight }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // x @ w^T, no bias term.
        x.matmul(&self.weight.t()?)
    }
}

struct Dropout {}

impl Dropout {
    fn new() -> Self {
        Self {}
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // TODO: actual dropout; identity for now (inference only).
        Ok(x.clone())
    }
}

struct LayerNorm {
    scale: Tensor,
}

impl LayerNorm {
    fn new(scale: Tensor) -> Self {
        Self { scale }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Simplified RMS-style normalization: no mean subtraction and no bias,
        // only a learned scale. Relies on `sum` keeping the summed dim so the
        // broadcast below lines up.
        let (seq_len, hidden_size) = x.shape().r2()?;
        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
        let size = self.scale.shape().r1()?;
        let scale = self.scale.broadcast_as((seq_len, size))?;
        scale * x_normed
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
struct BertEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    token_type_embeddings: Embedding,
    position_ids: Tensor,
    token_type_ids: Tensor,
}

struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    dropout: Dropout,
}

struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    dropout: Dropout,
}

impl BertSelfOutput {
    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.dropout.forward(&hidden_states)?;
        // Residual connection followed by layer norm.
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}
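// A minimal smoke test for the simplified LayerNorm above. This is a sketch,
// not a definitive check: it assumes candle's CPU backend and the
// `Device::Cpu` / `Tensor::new` constructors are available in this build.
#[cfg(test)]
mod layer_norm_tests {
    use super::*;

    #[test]
    fn layer_norm_preserves_shape() -> Result<()> {
        let device = candle::Device::Cpu;
        // (seq_len = 2, hidden_size = 3) input with a unit scale vector.
        let x = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
        let scale = Tensor::new(&[1f32, 1., 1.], &device)?;
        let ln = LayerNorm::new(scale);
        let y = ln.forward(&x)?;
        // The normalization is elementwise, so the shape must be unchanged.
        assert_eq!(y.shape().r2()?, (2, 3));
        Ok(())
    }
}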
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
}

impl BertAttention {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
struct BertIntermediate {
    dense: Linear,
    intermediate_act: HiddenAct,
}

impl BertIntermediate {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Feed-forward expansion followed by the configured activation.
        let hidden_states = self.dense.forward(hidden_states)?;
        self.intermediate_act.forward(&hidden_states)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
struct BertOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    dropout: Dropout,
}

impl BertOutput {
    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let hidden_states = self.dropout.forward(&hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
struct BertLayer {
    attention: BertAttention,
    intermediate: BertIntermediate,
    output: BertOutput,
}

impl BertLayer {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct BertEncoder {
    layers: Vec<BertLayer>,
}

impl BertEncoder {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
}

impl BertModel {
    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

fn main() -> R<()> {
    Ok(())
}
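// A minimal sketch exercising the Linear + activation pattern used by
// BertIntermediate. The weight values and shapes are illustrative only,
// assuming candle's CPU backend.
#[cfg(test)]
mod intermediate_tests {
    use super::*;

    #[test]
    fn linear_then_gelu_expands_hidden_dim() -> Result<()> {
        let device = candle::Device::Cpu;
        // (2, 3) input and a (4, 3) weight: x @ w^T yields a (2, 4) output,
        // mirroring hidden_size -> intermediate_size in the real model.
        let x = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
        let w = Tensor::new(&[[0.1f32, 0.2, 0.3]; 4], &device)?;
        let intermediate = BertIntermediate {
            dense: Linear::new(w),
            intermediate_act: HiddenAct::Gelu,
        };
        let y = intermediate.forward(&x)?;
        assert_eq!(y.shape().r2()?, (2, 4));
        Ok(())
    }
}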