Bugfix for transpose.

laurent
2023-07-03 17:06:23 +01:00
parent a7f03a7bb6
commit 1ea6690557

@@ -342,8 +342,9 @@ impl BertSelfAttention {
         new_x_shape.pop();
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
-        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
-        Ok(xs)
+        // Be cautious about the transposition if adding a batch dim!
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(0, 1)?;
+        Ok(xs.contiguous()?)
     }

     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
@@ -361,7 +362,7 @@ impl BertSelfAttention {
         let attention_probs = self.dropout.forward(&attention_probs)?;
         let context_layer = attention_probs.matmul(&value_layer)?;
-        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+        let context_layer = context_layer.transpose(0, 1)?.contiguous()?;
         let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
         Ok(context_layer)
     }
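For context (not part of the commit), below is a minimal sketch of the shape bookkeeping this fix restores. It assumes the candle Tensor API used in the diff (reshape, transpose, contiguous) and a batch-free layout where the hidden states are (seq_len, hidden_size); the crate path candle_core and the toy dimensions are illustrative assumptions, not taken from the commit.

// Crate name may be `candle` rather than `candle_core` in older versions of the repo.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Illustrative sizes, not from the commit.
    let (seq_len, num_heads, head_size): (usize, usize, usize) = (5, 4, 8);
    let hidden_size = num_heads * head_size;
    let xs = Tensor::zeros((seq_len, hidden_size), DType::F32, &Device::Cpu)?;

    // transpose_for_scores: split the hidden dim into heads,
    // giving (seq_len, num_heads, head_size).
    let xs = xs.reshape((seq_len, num_heads, head_size))?;
    // With no batch dimension the head axis must be swapped with dim 0,
    // hence transpose(0, 1): (num_heads, seq_len, head_size).
    // The buggy transpose(1, 2) would instead swap num_heads with head_size.
    let xs = xs.transpose(0, 1)?.contiguous()?;
    assert_eq!(xs.dims(), &[num_heads, seq_len, head_size]);

    // forward: after the attention matmul, the inverse swap restores
    // (seq_len, num_heads, head_size); flattening the last two dims (done with
    // flatten in the diff, reshape here for brevity) yields (seq_len, hidden_size).
    let ys = xs.transpose(0, 1)?.contiguous()?;
    let ys = ys.reshape((seq_len, hidden_size))?;
    assert_eq!(ys.dims(), &[seq_len, hidden_size]);
    Ok(())
}

The extra contiguous() on the return value of transpose_for_scores is presumably there so the subsequent matmul in forward() operates on a contiguous layout after the axis swap.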