Use an attention mask in the e5 padding case. (#1085)

Author: Laurent Mazare
Date: 2023-10-13 19:53:40 +02:00
Committed by: GitHub
Parent: 07af87a1d8
Commit: 75989fc3b7
2 changed files with 26 additions and 11 deletions
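
Background for the change: when a batch of sentences is padded to a common length, the padded key positions should not receive any attention weight. This commit threads the tokenizer's attention_mask (1 for real tokens, 0 for padding) through the Python BERT port and adds it to the raw attention scores as a 0 / -inf bias before the softmax. A minimal NumPy sketch of that additive-mask idea (illustration only, not part of the commit):

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Scores for one head: 2 query positions x 4 key positions,
# where the last key position is padding.
scores = np.array([[0.5, 0.1, 0.3, 0.9],
                   [0.2, 0.4, 0.6, 0.8]])
attention_mask = np.array([1, 1, 1, 0])  # 1 = real token, 0 = padding

# Turn the 0/1 mask into an additive bias: 0 for kept keys, -inf for padding.
bias = np.where(attention_mask == 1, 0.0, -np.inf)
probs = softmax(scores + bias)
print(probs[:, -1])  # the padded key gets weight 0 in every row: [0. 0.]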

View File

@@ -50,7 +50,8 @@ if __name__ == "__main__":
     tokenized = tokenizer(sentences, padding=True)
     tokens = Tensor(tokenized["input_ids"])
     token_type_ids = Tensor(tokenized["token_type_ids"])
-    encoder_out, _ = model.forward(tokens, token_type_ids)
+    attention_mask = Tensor(tokenized["attention_mask"])
+    encoder_out, _ = model.forward(tokens, token_type_ids, attention_mask=attention_mask)
     hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
     hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
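
For reference, this is roughly what the tokenizer side of the script produces for a padded batch; the checkpoint name and sentences below are placeholders for illustration, not taken from the actual script:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-small-v2")  # hypothetical checkpoint
sentences = ["query: a short one", "query: a somewhat longer sentence for padding"]
tokenized = tokenizer(sentences, padding=True)

# input_ids are padded to the batch's max length; attention_mask marks real
# tokens with 1 and padding with 0, which is exactly what the new
# forward(..., attention_mask=...) argument consumes.
print(len(tokenized["input_ids"][0]) == len(tokenized["input_ids"][1]))  # True
print(tokenized["attention_mask"][0])  # e.g. [1, 1, 1, 1, 1, 1, 0, 0, 0]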

View File

@@ -46,7 +46,7 @@ class BertSelfAttention(Module):
         x = x.reshape(new_x_shape).transpose(1, 2)
         return x.contiguous()

-    def forward(self, hidden_states: Tensor) -> Tensor:
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
         query = self.query.forward(hidden_states)
         key = self.key.forward(hidden_states)
         value = self.value.forward(hidden_states)
@@ -56,7 +56,11 @@ class BertSelfAttention(Module):
         value = self.transpose_for_scores(value)

         attention_scores = query.matmul(key.t())
-        attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
+        attention_scores = attention_scores / float(self.attention_head_size) ** 0.5
+        if attention_mask is not None:
+            b_size, _, _, last_dim = attention_scores.shape
+            attention_scores = attention_scores.broadcast_add(
+                attention_mask.reshape((b_size, 1, 1, last_dim)))
         attention_probs = F.softmax(attention_scores, dim=-1)

         context_layer = attention_probs.matmul(value)
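
The reshape above turns the per-batch mask of shape (batch, seq_len) into (batch, 1, 1, seq_len) so that it broadcasts over the head and query dimensions of the (batch, heads, seq_len, seq_len) score tensor. A quick NumPy shape check (illustration only, not part of the commit):

import numpy as np

b_size, n_heads, seq_len = 2, 12, 8
attention_scores = np.zeros((b_size, n_heads, seq_len, seq_len))
attention_mask = np.zeros((b_size, seq_len))  # already converted to a 0/-inf bias

# Broadcasting (b, 1, 1, s) against (b, h, s, s) applies the same key mask
# to every head and every query position.
masked = attention_scores + attention_mask.reshape((b_size, 1, 1, seq_len))
print(masked.shape)  # (2, 12, 8, 8)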
@@ -82,8 +86,8 @@ class BertAttention(Module):
         self.self = BertSelfAttention(config)
         self.output = BertSelfOutput(config)

-    def forward(self, hidden_states: Tensor) -> Tensor:
-        self_outputs = self.self.forward(hidden_states)
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+        self_outputs = self.self.forward(hidden_states, attention_mask=attention_mask)
         attention_output = self.output.forward(self_outputs, hidden_states)
         return attention_output
@@ -117,8 +121,8 @@ class BertLayer(Module):
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

-    def forward(self, hidden_states: Tensor) -> Tensor:
-        attention_output = self.attention.forward(hidden_states)
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+        attention_output = self.attention.forward(hidden_states, attention_mask=attention_mask)
         # TODO: Support cross-attention?
         # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
         # TODO: Support something similar to `apply_chunking_to_forward`?
@@ -134,9 +138,9 @@ class BertEncoder(Module):
         for _ in range(config.num_hidden_layers):
             self.layer.append(BertLayer(config))

-    def forward(self, hidden_states: Tensor) -> Tensor:
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
         for l in self.layer:
-            hidden_states = l.forward(hidden_states)
+            hidden_states = l.forward(hidden_states, attention_mask=attention_mask)
         return hidden_states
@@ -178,6 +182,13 @@ class BertPooler(Module):
         return pooled_output


+def masked_fill(on_false: float, mask: Tensor, on_true: float):
+    shape = mask.shape
+    on_true = candle.tensor(on_true).broadcast_as(shape)
+    on_false = candle.tensor(on_false).broadcast_as(shape)
+    return mask.where_cond(on_true, on_false)
+
+
 # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
 class BertModel(Module):
     def __init__(self, config: Config, add_pooling_layer=True) -> None:
@@ -187,8 +198,11 @@ class BertModel(Module):
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None

-    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+    def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask=None) -> Tuple[Tensor, Optional[Tensor]]:
+        if attention_mask is not None:
+            # Replace 0s with -inf, and 1s with 0s.
+            attention_mask = masked_fill(float("-inf"), attention_mask, 1.0)
         embeddings = self.embeddings.forward(input_ids, token_type_ids)
-        encoder_out = self.encoder.forward(embeddings)
+        encoder_out = self.encoder.forward(embeddings, attention_mask=attention_mask)
         pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
         return encoder_out, pooled_output
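
A quick sanity check of the new masked_fill helper, assuming it runs in the same module as the code above and uses only the candle Python API this file already relies on (Tensor construction from nested lists, shape, where_cond); the example values are made up:

from candle import Tensor  # assumes the candle package exposes Tensor at the top level

mask = Tensor([[1, 1, 0], [1, 0, 0]])
# Non-zero mask entries keep the on_true value, padded positions become -inf,
# which is the additive bias BertSelfAttention adds to its scores.
bias = masked_fill(float("-inf"), mask, 1.0)
print(bias.shape)  # same shape as the mask: batch x seq_len

Note that BertModel.forward passes 1.0 rather than 0.0 as on_true; since that constant is added to every unmasked score in a row, the softmax output is unchanged, while the -inf entries still zero out the padded keys.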