mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Complete (?) the forward pass.
This commit is contained in:
@ -158,11 +158,37 @@ struct BertSelfAttention {
|
||||
key: Linear,
|
||||
value: Linear,
|
||||
dropout: Dropout,
|
||||
num_attention_heads: usize,
|
||||
attention_head_size: usize,
|
||||
}
|
||||
|
||||
impl BertSelfAttention {
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut new_x_shape = xs.dims().to_vec();
|
||||
new_x_shape.pop();
|
||||
new_x_shape.push(self.num_attention_heads);
|
||||
new_x_shape.push(self.attention_head_size);
|
||||
xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let query_layer = self.query.forward(hidden_states)?;
|
||||
let key_layer = self.key.forward(hidden_states)?;
|
||||
let value_layer = self.value.forward(hidden_states)?;
|
||||
|
||||
let query_layer = self.transpose_for_scores(&query_layer)?;
|
||||
let key_layer = self.transpose_for_scores(&key_layer)?;
|
||||
let value_layer = self.transpose_for_scores(&value_layer)?;
|
||||
|
||||
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
|
||||
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
|
||||
let attention_probs = attention_scores.softmax(attention_scores.rank() - 1)?;
|
||||
let attention_probs = self.dropout.forward(&attention_probs)?;
|
||||
|
||||
let context_layer = attention_probs.matmul(&value_layer)?;
|
||||
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
|
||||
let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
|
||||
Ok(context_layer)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user