Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 18:28:24 +00:00
Bugfix for transpose.
@@ -342,8 +342,9 @@ impl BertSelfAttention {
         new_x_shape.pop();
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
-        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
-        Ok(xs)
+        // Be cautious about the transposition if adding a batch dim!
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(0, 1)?;
+        Ok(xs.contiguous()?)
     }
 
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
@@ -361,7 +362,7 @@ impl BertSelfAttention {
         let attention_probs = self.dropout.forward(&attention_probs)?;
 
         let context_layer = attention_probs.matmul(&value_layer)?;
-        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+        let context_layer = context_layer.transpose(0, 1)?.contiguous()?;
         let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
         Ok(context_layer)
     }
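For context, a minimal standalone sketch (not part of the commit) of why the transpose indices move from (1, 2) to (0, 1): the example assumes the tensors here carry no batch dimension, so the reshape yields (seq_len, num_heads, head_size) and the head axis sits at index 1. The sizes, variable names, and the candle_core crate name below are illustrative assumptions.

// Illustrative sketch only; sizes are made up (seq_len = 4, num_heads = 2, head_size = 3).
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for the hidden states: shape (seq_len, num_heads * head_size), no batch dim.
    let xs = Tensor::zeros((4, 6), DType::F32, &device)?;
    // Reshape to (seq_len, num_heads, head_size); with no batch dim the head axis is index 1,
    // so it is swapped with the sequence axis via transpose(0, 1) rather than transpose(1, 2).
    let xs = xs.reshape((4, 2, 3))?.transpose(0, 1)?;
    assert_eq!(xs.dims(), &[2, 4, 3]);
    // transpose only reorders strides; contiguous() materializes the new layout, which is
    // what the added .contiguous() call in the commit provides before later matmuls.
    let xs = xs.contiguous()?;
    assert_eq!(xs.dims(), &[2, 4, 3]);
    Ok(())
}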