Mirror of https://github.com/huggingface/candle.git, synced 2025-06-20 12:06:35 +00:00
PyO3: Add CI (#1135)

* Add PyO3 ci
* Update python.yml
* Format `bert.py`
@@ -59,8 +59,7 @@ class BertSelfAttention(Module):
         attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
         if attention_mask is not None:
             b_size, _, _, last_dim = attention_scores.shape
-            attention_scores = attention_scores.broadcast_add(
-                attention_mask.reshape((b_size, 1, 1, last_dim)))
+            attention_scores = attention_scores.broadcast_add(attention_mask.reshape((b_size, 1, 1, last_dim)))
         attention_probs = F.softmax(attention_scores, dim=-1)

         context_layer = attention_probs.matmul(value)
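
This hunk is the core of masked scaled dot-product attention: scale the raw scores by the square root of the head size, broadcast-add an additive mask, then softmax over the key axis. Below is a minimal NumPy sketch of the same arithmetic (NumPy rather than candle's Python API; the b_size, num_heads, seq_len, and head_size values are made up for illustration):

import numpy as np

# Shapes follow the diff: attention_scores is (b_size, num_heads, seq_len, seq_len),
# the additive mask is (b_size, seq_len) with 0.0 = attend and -inf = ignore.
b_size, num_heads, seq_len, head_size = 2, 4, 5, 8
attention_scores = np.random.randn(b_size, num_heads, seq_len, seq_len)

keep = np.random.rand(b_size, seq_len) > 0.3
keep[:, 0] = True  # guarantee at least one unmasked key per batch row
attention_mask = np.where(keep, 0.0, -np.inf)

attention_scores = attention_scores / float(head_size) ** 0.5
# The reshape((b_size, 1, 1, last_dim)) in the diff lets the mask broadcast
# across the head and query dimensions; plain + does the same in NumPy.
attention_scores = attention_scores + attention_mask.reshape(b_size, 1, 1, seq_len)

# Softmax over the last (key) axis: -inf scores become exactly 0 probability.
e = np.exp(attention_scores - attention_scores.max(axis=-1, keepdims=True))
attention_probs = e / e.sum(axis=-1, keepdims=True)
assert np.allclose(attention_probs.sum(axis=-1), 1.0)
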
@@ -198,7 +197,9 @@ class BertModel(Module):
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None

-    def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(
+        self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None
+    ) -> Tuple[Tensor, Optional[Tensor]]:
         if attention_mask is not None:
             # Replace 0s with -inf, and 1s with 0s.
             attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)
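
The masked_fill call here turns a 0/1 padding mask into the additive 0/-inf mask consumed by the attention hunk above: per the diff's own comment, 1s (real tokens) become 0.0 and 0s (padding) become -inf. A hedged NumPy sketch of that behavior follows (not candle's actual helper, whose implementation isn't shown in this diff):

import numpy as np

def masked_fill(on_false, mask, on_true_value):
    # Where mask equals on_true_value (1.0), keep 0.0; elsewhere use on_false (-inf).
    # This mirrors the behavior described by the diff's comment, not the real helper.
    return np.where(mask == on_true_value, 0.0, on_false)

attention_mask = np.array([[1.0, 1.0, 1.0, 0.0, 0.0]])  # 1 = token, 0 = padding
print(masked_fill(float("-inf"), attention_mask, 1.0))
# [[  0.   0.   0. -inf -inf]]

Adding this mask to the scores before the softmax (rather than multiplying the probabilities afterwards) keeps the attention distribution normalized over the unmasked positions only.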