From 12ac9e146090a20334e309a64aa272f4f89a7f46 Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 3 Jul 2023 13:33:32 +0100
Subject: [PATCH] Complete (?) the forward pass.

---
 candle-examples/examples/bert/main.rs | 30 ++++++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index fce7f0f1..d3b0e45f 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -158,11 +158,37 @@ struct BertSelfAttention {
     key: Linear,
     value: Linear,
     dropout: Dropout,
+    num_attention_heads: usize,
+    attention_head_size: usize,
 }
 
 impl BertSelfAttention {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut new_x_shape = xs.dims().to_vec();
+        new_x_shape.pop();
+        new_x_shape.push(self.num_attention_heads);
+        new_x_shape.push(self.attention_head_size);
+        xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let query_layer = self.query.forward(hidden_states)?;
+        let key_layer = self.key.forward(hidden_states)?;
+        let value_layer = self.value.forward(hidden_states)?;
+
+        let query_layer = self.transpose_for_scores(&query_layer)?;
+        let key_layer = self.transpose_for_scores(&key_layer)?;
+        let value_layer = self.transpose_for_scores(&value_layer)?;
+
+        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
+        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_probs = attention_scores.softmax(attention_scores.rank() - 1)?;
+        let attention_probs = self.dropout.forward(&attention_probs)?;
+
+        let context_layer = attention_probs.matmul(&value_layer)?;
+        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+        let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
+        Ok(context_layer)
     }
 }
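
For reference, below is a minimal, framework-free sketch of the single-head scaled dot-product attention that the new forward pass computes per attention head (dropout omitted). It is not part of the patch: the function name and the plain Vec<f32> representation are illustrative assumptions, not candle API.

// Sketch only (not candle API): single-head scaled dot-product attention over
// plain Vec<f32> rows, mirroring what BertSelfAttention::forward does per head.
// q, k, v each have shape [seq_len][head_size]; dropout is omitted.
fn scaled_dot_product_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let scale = (q[0].len() as f32).sqrt();
    q.iter()
        .map(|qi| {
            // attention_scores[j] = (q_i . k_j) / sqrt(head_size)
            let scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() / scale)
                .collect();
            // softmax over the key positions (the last dimension)
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
            let denom: f32 = exps.iter().sum();
            let probs: Vec<f32> = exps.iter().map(|e| e / denom).collect();
            // context_i = sum_j probs[j] * v_j (probability-weighted average of value rows)
            (0..v[0].len())
                .map(|d| probs.iter().zip(v).map(|(p, vj)| p * vj[d]).sum())
                .collect()
        })
        .collect()
}

fn main() {
    // Tiny usage example: seq_len = 2, head_size = 2.
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let k = q.clone();
    let v = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    println!("{:?}", scaled_dot_product_attention(&q, &k, &v));
}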