Add more to the forward pass.

This commit is contained in:
laurent
2023-07-03 13:04:41 +01:00
parent 2309c5fac5
commit d796945ad8

@@ -131,12 +131,28 @@ impl LayerNorm {
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
 struct BertEmbeddings {
     word_embeddings: Embedding,
-    position_embeddings: Embedding,
+    position_embeddings: Option<Embedding>,
     token_type_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    dropout: Dropout,
     position_ids: Tensor,
     token_type_ids: Tensor,
 }

+impl BertEmbeddings {
+    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
+        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+        let mut embeddings = (input_embeddings + token_type_embeddings)?;
+        if let Some(position_embeddings) = &self.position_embeddings {
+            embeddings = (&embeddings + position_embeddings.forward(&embeddings))?
+        }
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        let embeddings = self.dropout.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
 struct BertSelfAttention {
     query: Linear,
     key: Linear,
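Note that the position branch above looks up the summed embeddings themselves; in the reference PyTorch implementation the lookup takes integer position ids. A minimal sketch of that variant, assuming `position_ids` holds a `(1, max_position_embeddings)` index tensor and that `dim`, `narrow` and `broadcast_add` are available on `Tensor` (assumptions, not part of this commit):

// Sketch only: feed position indices (not embeddings) to the position table.
if let Some(position_embeddings) = &self.position_embeddings {
    // Assumption: self.position_ids is a (1, max_position_embeddings) index tensor.
    let seq_len = input_ids.dim(1)?;
    let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
    embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
}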
@@ -144,6 +160,12 @@ struct BertSelfAttention {
     dropout: Dropout,
 }

+impl BertSelfAttention {
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
 struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
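`BertSelfAttention::forward` is left as a `todo!()` in this commit. For orientation, a sketch of the scaled dot-product attention it will eventually need, following the transformers layout: the `value` projection and the `num_attention_heads` / `attention_head_size` fields are assumed (they are not visible in this diff), and the `softmax` call stands in for whichever softmax helper the crate exposes.

// Sketch only: standard multi-head self-attention over (batch, seq, hidden) inputs.
// `value`, `num_attention_heads` and `attention_head_size` are assumed fields.
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
    let (b, seq_len, _hidden) = hidden_states.dims3()?;
    // (b, seq, hidden) -> (b, heads, seq, head_size)
    let split_heads = |xs: &Tensor| -> Result<Tensor> {
        xs.reshape((b, seq_len, self.num_attention_heads, self.attention_head_size))?
            .transpose(1, 2)?
            .contiguous()
    };
    let query = split_heads(&self.query.forward(hidden_states)?)?;
    let key = split_heads(&self.key.forward(hidden_states)?)?;
    let value = split_heads(&self.value.forward(hidden_states)?)?;
    // Scores scaled by sqrt(head_size), softmax over the key dimension (last dim).
    // Assumption: a tensor-level softmax is available; adjust to the crate's helper.
    let scale = 1.0 / (self.attention_head_size as f64).sqrt();
    let scores = (query.matmul(&key.t()?)? * scale)?;
    let probs = self.dropout.forward(&scores.softmax(3)?)?;
    // Weighted sum of the values, then merge the heads back into the hidden size.
    probs
        .matmul(&value)? // (b, heads, seq, head_size)
        .transpose(1, 2)? // (b, seq, heads, head_size)
        .contiguous()?
        .reshape((b, seq_len, self.num_attention_heads * self.attention_head_size))
}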
@@ -165,8 +187,10 @@ struct BertAttention {
 }

 impl BertAttention {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
+        Ok(attention_output)
     }
 }
@@ -206,8 +230,16 @@ struct BertLayer {
 }

 impl BertLayer {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let attention_output = self.attention.forward(hidden_states)?;
+        // TODO: Support cross-attention?
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+        // TODO: Support something similar to `apply_chunking_to_forward`?
+        let intermediate_output = self.intermediate.forward(&attention_output)?;
+        let layer_output = self
+            .output
+            .forward(&intermediate_output, &attention_output)?;
+        Ok(layer_output)
     }
 }
@@ -217,8 +249,13 @@ struct BertEncoder {
 }

 impl BertEncoder {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let mut hidden_states = hidden_states.clone();
+        // Use a loop rather than a fold as it's easier to modify when adding debug/...
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states)?
+        }
+        Ok(hidden_states)
     }
 }
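The comment in `BertEncoder::forward` motivates the explicit loop; for comparison, the fold form it avoids would look like this (plain iterator code, shown only to illustrate the trade-off):

// Equivalent, but harder to instrument with per-layer debug output.
let hidden_states = self
    .layers
    .iter()
    .try_fold(hidden_states.clone(), |xs, layer| layer.forward(&xs))?;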
@@ -229,8 +266,10 @@ struct BertModel {
 }

 impl BertModel {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
+        let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
+        let sequence_output = self.encoder.forward(&embedding_output)?;
+        Ok(sequence_output)
     }
 }
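A hypothetical call site for the model as wired so far. The second parameter of `BertModel::forward` is named `position_ids` but is passed to `BertEmbeddings::forward` as its `token_type_ids` argument, so segment ids go there; the tensor construction below (`Tensor::new`, `zeros_like`, `Device::Cpu`) is an assumption about the surrounding example code, not part of this commit:

// Sketch only: run the encoder on one tokenized sequence.
let device = Device::Cpu;
let input_ids = Tensor::new(&[[101u32, 7592, 2088, 102]], &device)?; // (1, seq_len)
let token_type_ids = input_ids.zeros_like()?; // single segment
let sequence_output = model.forward(&input_ids, &token_type_ids)?; // (1, seq_len, hidden)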