Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 02:16:37 +00:00
Use an attention mask in the e5 padding case. (#1085)
@@ -50,7 +50,8 @@ if __name__ == "__main__":
     tokenized = tokenizer(sentences, padding=True)
     tokens = Tensor(tokenized["input_ids"])
     token_type_ids = Tensor(tokenized["token_type_ids"])
-    encoder_out, _ = model.forward(tokens, token_type_ids)
+    attention_mask = Tensor(tokenized["attention_mask"])
+    encoder_out, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)
 
     hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
     hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
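For context (not part of the commit), a rough sketch of what the padded batch looks like; the sentences and id values below are invented. With padding=True the tokenizer pads every sentence to the longest one in the batch, and attention_mask marks real tokens with 1 and padding with 0, which is what the new forward argument consumes.

# Hypothetical illustration of a padded two-sentence batch (values invented).
tokenized = {
    "input_ids":      [[101, 7592, 102,   0], [101, 2088, 2003, 102]],
    "token_type_ids": [[  0,    0,   0,   0], [  0,    0,    0,   0]],
    "attention_mask": [[  1,    1,   1,   0], [  1,    1,    1,   1]],
}
# Without the mask, the padded position in the shorter sentence is attended to
# like a real token and skews the resulting e5 embedding; with the mask it can
# be given zero attention weight.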
@@ -46,7 +46,7 @@ class BertSelfAttention(Module):
         x = x.reshape(new_x_shape).transpose(1, 2)
         return x.contiguous()
 
-    def forward(self, hidden_states: Tensor) -> Tensor:
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
         query = self.query.forward(hidden_states)
         key = self.key.forward(hidden_states)
         value = self.value.forward(hidden_states)
@@ -56,7 +56,11 @@ class BertSelfAttention(Module):
         value = self.transpose_for_scores(value)
 
         attention_scores = query.matmul(key.t())
-        attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
+        attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
+        if attention_mask is not None:
+            b_size, _, _, last_dim = attention_scores.shape
+            attention_scores = attention_scores.broadcast_add(
+                attention_mask.reshape((b_size, 1, 1, last_dim)))
         attention_probs = F.softmax(attention_scores, dim=-1)
 
         context_layer = attention_probs.matmul(value)
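A small sketch of the broadcast performed here, written with numpy purely for illustration: attention_scores has shape (batch, heads, seq, seq) and the additive mask is reshaped to (batch, 1, 1, seq), so one bias per key position is applied across every head and every query position.

import numpy as np

b_size, n_heads, seq_len = 2, 4, 6
attention_scores = np.zeros((b_size, n_heads, seq_len, seq_len))
# Additive mask: 0.0 for real tokens, -inf for the padded key positions.
attention_mask = np.array([[0.0] * 5 + [-np.inf],
                           [0.0] * 6])
# Same idea as the broadcast_add above: reshape to (b, 1, 1, seq) and let
# broadcasting copy the bias across heads and query positions.
masked = attention_scores + attention_mask.reshape((b_size, 1, 1, seq_len))
print(masked.shape)     # (2, 4, 6, 6)
print(masked[0, 0, 0])  # [0. 0. 0. 0. 0. -inf] -- padded key is masked out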
@@ -82,8 +86,8 @@ class BertAttention(Module):
         self.self = BertSelfAttention(config)
         self.output = BertSelfOutput(config)
 
-    def forward(self, hidden_states: Tensor) -> Tensor:
-        self_outputs = self.self.forward(hidden_states)
+    def forward(self, hidden_states: Tensor, attention_mask: None) -> Tensor:
+        self_outputs = self.self.forward(hidden_states, attention_mask=attention_mask)
         attention_output = self.output.forward(self_outputs, hidden_states)
         return attention_output
 
@@ -117,8 +121,8 @@ class BertLayer(Module):
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
-    def forward(self, hidden_states: Tensor) -> Tensor:
-        attention_output = self.attention.forward(hidden_states)
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+        attention_output = self.attention.forward(hidden_states, attention_mask=attention_mask)
         # TODO: Support cross-attention?
         # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
         # TODO: Support something similar to `apply_chunking_to_forward`?
@@ -134,9 +138,9 @@ class BertEncoder(Module):
         for _ in range(config.num_hidden_layers):
             self.layer.append(BertLayer(config))
 
-    def forward(self, hidden_states: Tensor) -> Tensor:
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
         for l in self.layer:
-            hidden_states = l.forward(hidden_states)
+            hidden_states = l.forward(hidden_states, attention_mask=attention_mask)
         return hidden_states
 
 
@@ -178,6 +182,13 @@ class BertPooler(Module):
         return pooled_output
 
 
+def masked_fill(on_false: float, mask: Tensor, on_true: float):
+    shape = mask.shape
+    on_true = candle.tensor(on_true).broadcast_as(shape)
+    on_false = candle.tensor(on_false).broadcast_as(shape)
+    return mask.where_cond(on_true, on_false)
+
+
 # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
 class BertModel(Module):
     def __init__(self, config: Config, add_pooling_layer=True) -> None:
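A plain-Python reading of this helper, illustrative only and assuming where_cond selects on_true where the mask is non-zero (which is what the call site in the next hunk relies on): real positions receive a finite constant and padded positions receive -inf.

# Element-wise sketch of masked_fill(on_false, mask, on_true) for one mask row.
def masked_fill_sketch(on_false, mask_row, on_true):
    return [on_true if m else on_false for m in mask_row]

print(masked_fill_sketch(float("-inf"), [1, 1, 1, 0, 0], 1.0))
# [1.0, 1.0, 1.0, -inf, -inf]

Because the same finite value is added to every unmasked score in a row, it cancels inside the softmax; only the -inf entries change the result, by forcing the padded positions to zero weight.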
@@ -187,8 +198,11 @@ class BertModel(Module):
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
 
-    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+        if attention_mask is not None:
+            # Replace 0s with -inf, and 1s with 0s.
+            attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)
         embeddings = self.embeddings.forward(input_ids, token_type_ids)
-        encoder_out = self.encoder.forward(embeddings)
+        encoder_out = self.encoder.forward(embeddings, attention_mask=attention_mask)
         pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
         return encoder_out, pooled_output
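To see why the additive mask removes padding from the attention, here is a tiny standalone sketch (plain Python, not candle code): adding -inf to a score drives its softmax weight to exactly zero, while the remaining scores renormalize among themselves.

import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 0.5, 3.0]            # raw attention scores for one query
mask = [0.0, 0.0, 0.0, float("-inf")]    # last key position is padding
print(softmax([s + m for s, m in zip(scores, mask)]))
# [0.628..., 0.231..., 0.140..., 0.0] -- the padded position gets zero weight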