mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add a simpler way to specify the dim index for some ops.
This commit is contained in:
@ -386,12 +386,12 @@ impl BertSelfAttention {
|
||||
|
||||
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
|
||||
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
|
||||
let attention_probs = attention_scores.softmax(attention_scores.rank() - 1)?;
|
||||
let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
|
||||
let attention_probs = self.dropout.forward(&attention_probs)?;
|
||||
|
||||
let context_layer = attention_probs.matmul(&value_layer)?;
|
||||
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
|
||||
let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
|
||||
let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
|
||||
Ok(context_layer)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user