From d796945ad885fdea57c8bd0bc3281bec4a12b17e Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 3 Jul 2023 13:04:41 +0100
Subject: [PATCH] Add more to the forward pass.

---
 candle-examples/examples/bert/main.rs | 57 ++++++++++++++++++++++-----
 1 file changed, 48 insertions(+), 9 deletions(-)

diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 68864a7f..fce7f0f1 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -131,12 +131,28 @@ impl LayerNorm {
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
 struct BertEmbeddings {
     word_embeddings: Embedding,
-    position_embeddings: Embedding,
+    position_embeddings: Option<Embedding>,
     token_type_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
     position_ids: Tensor,
     token_type_ids: Tensor,
 }
 
+impl BertEmbeddings {
+    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
+        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+        let mut embeddings = (input_embeddings + token_type_embeddings)?;
+        if let Some(position_embeddings) = &self.position_embeddings {
+            embeddings = (&embeddings + position_embeddings.forward(&embeddings))?
+        }
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        let embeddings = self.dropout.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
 struct BertSelfAttention {
     query: Linear,
     key: Linear,
@@ -144,6 +160,12 @@ struct BertSelfAttention {
     dropout: Dropout,
 }
 
+impl BertSelfAttention {
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
 struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -165,8 +187,10 @@ struct BertAttention {
 }
 
 impl BertAttention {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
+        Ok(attention_output)
     }
 }
 
@@ -206,8 +230,16 @@ struct BertLayer {
 }
 
 impl BertLayer {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let attention_output = self.attention.forward(hidden_states)?;
+        // TODO: Support cross-attention?
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+        // TODO: Support something similar to `apply_chunking_to_forward`?
+        let intermediate_output = self.intermediate.forward(&attention_output)?;
+        let layer_output = self
+            .output
+            .forward(&intermediate_output, &attention_output)?;
+        Ok(layer_output)
     }
 }
 
@@ -217,8 +249,13 @@ struct BertEncoder {
 }
 
 impl BertEncoder {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let mut hidden_states = hidden_states.clone();
+        // Use a loop rather than a fold as it's easier to modify when adding debug/...
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states)?
+        }
+        Ok(hidden_states)
    }
 }
 
@@ -229,8 +266,10 @@ struct BertModel {
 }
 
 impl BertModel {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
+        let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
+        let sequence_output = self.encoder.forward(&embedding_output)?;
+        Ok(sequence_output)
     }
 }
 
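A minimal sketch of how the forward pass added above could be exercised, assuming a `BertModel` value has already been built from the checkpoint elsewhere in the example, that `Device` is in scope alongside the `Tensor` and `Result` imports main.rs already uses, and that the literal token ids are placeholders rather than real tokenizer output (the helper name `run_forward` is hypothetical, not part of the patch):

// Hypothetical helper, not part of the diff above.
fn run_forward(model: &BertModel) -> Result<Tensor> {
    let device = Device::Cpu;
    // A single batch of placeholder u32 token ids, shape (1, 6).
    let input_ids = Tensor::new(&[[101u32, 7592, 1010, 2088, 999, 102]], &device)?;
    // The second parameter is named `position_ids` in this revision but is consumed as
    // token type ids by `BertEmbeddings::forward`; all zeros selects segment 0.
    let token_type_ids = input_ids.zeros_like()?;
    // Embeddings -> encoder stack; the result has shape (batch, seq_len, hidden_size).
    model.forward(&input_ids, &token_type_ids)
}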