diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 234e79d8..78ca97d3 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -342,8 +342,9 @@ impl BertSelfAttention {
         new_x_shape.pop();
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
-        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
-        Ok(xs)
+        // Be cautious about the transposition if adding a batch dim!
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(0, 1)?;
+        Ok(xs.contiguous()?)
     }
 
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
@@ -361,7 +362,7 @@ impl BertSelfAttention {
         let attention_probs = self.dropout.forward(&attention_probs)?;
 
         let context_layer = attention_probs.matmul(&value_layer)?;
-        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+        let context_layer = context_layer.transpose(0, 1)?.contiguous()?;
         let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
         Ok(context_layer)
     }
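
Illustration (not part of the patch): a minimal shape-bookkeeping sketch of why the swap becomes transpose(0, 1) once the batch dimension is dropped, and why transpose(1, 2) is the batched equivalent used before this change. It assumes the candle-core crate (imported as candle_core) and made-up sizes.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical sizes: 4 tokens, 2 heads of size 3 (hidden = 6).
    let (seq_len, num_heads, head_size) = (4usize, 2usize, 3usize);
    let hidden = num_heads * head_size;

    // Without a batch dim the hidden states are [seq_len, hidden].
    let xs = Tensor::arange(0f32, (seq_len * hidden) as f32, &device)?
        .reshape((seq_len, hidden))?;

    // Split hidden into heads -> [seq_len, num_heads, head_size], then move the
    // head dim to the front with transpose(0, 1) -> [num_heads, seq_len, head_size].
    let per_head = xs
        .reshape((seq_len, num_heads, head_size))?
        .transpose(0, 1)?
        .contiguous()?;
    assert_eq!(per_head.dims(), &[num_heads, seq_len, head_size]);

    // With a leading batch dim the split would instead be
    // [batch, seq_len, num_heads, head_size], and the equivalent swap is
    // transpose(1, 2), as in the pre-patch code.
    Ok(())
}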