Fix hardcoded f32 dtype for attention_mask. Use the model dtype for compatibility. (#2872)

Author: Manpreet Singh
Date: 2025-04-08 00:12:14 -04:00
Committed by: GitHub
Parent: 2f3bf42bcb
Commit: d339b01726
2 changed files with 14 additions and 6 deletions
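
For context, a minimal standalone sketch of the pattern both files now share: the additive attention mask is materialized in the caller-supplied model dtype instead of a hardcoded DType::F32. The sketch assumes the candle_core crate; the function and variable names here are illustrative, not the ones in the diff below.

use candle_core::{DType, Device, Result, Tensor};

// Illustrative stand-in for the helper touched in this commit: build the
// additive mask as (1 - mask) * f32::MIN in the requested dtype, so it can be
// broadcast against tensors of that same dtype.
fn extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
    let mask = attention_mask.to_dtype(dtype)?;
    (mask.ones_like()? - &mask)?.broadcast_mul(
        &Tensor::try_from(f32::MIN)?
            .to_device(mask.device())?
            .to_dtype(dtype)?,
    )
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // 1 = attend, 0 = padding; a single (batch, seq_len) row for illustration.
    let mask = Tensor::new(&[[1f32, 1., 1., 0.]], &device)?;
    // Passing the model dtype (e.g. BF16 for a bf16 checkpoint) keeps the mask
    // compatible with the tensors it is later combined with.
    let extended = extended_attention_mask(&mask, DType::BF16)?;
    println!("{extended}");
    Ok(())
}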


@@ -504,8 +504,9 @@ impl BertModel {
             Some(attention_mask) => attention_mask.clone(),
             None => input_ids.ones_like()?,
         };
+        let dtype = embedding_output.dtype();
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
-        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
         let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
         Ok(sequence_output)
     }
@@ -519,8 +520,11 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
     };
     let attention_mask = attention_mask.to_dtype(dtype)?;
     // torch.finfo(dtype).min
-    (attention_mask.ones_like()? - &attention_mask)?
-        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
+    (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
+        &Tensor::try_from(f32::MIN)?
+            .to_device(attention_mask.device())?
+            .to_dtype(dtype)?,
+    )
 }
 //https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766


@@ -514,8 +514,9 @@ impl ChineseClipTextTransformer {
             Some(attention_mask) => attention_mask.clone(),
             None => input_ids.ones_like()?,
         };
+        let dtype = embedding_output.dtype();
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
-        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
         let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
         let encoder_output = encoder_outputs.i((.., 0, ..))?;
         let pooled_output = match &self.pooler {
@@ -535,6 +536,9 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
     };
     let attention_mask = attention_mask.to_dtype(dtype)?;
     // torch.finfo(dtype).min
-    (attention_mask.ones_like()? - &attention_mask)?
-        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
+    (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
+        &Tensor::try_from(f32::MIN)?
+            .to_device(attention_mask.device())?
+            .to_dtype(dtype)?,
+    )
 }
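
As a usage note on why the hardcoded F32 surfaced as a bug: candle's binary and broadcast ops expect both operands to share a dtype, so an F32 extended mask cannot be combined with half-precision attention tensors. A minimal sketch under that assumption (again using candle_core; the shapes are illustrative and not taken from the models above):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for half-precision attention scores: (batch, heads, seq, seq).
    let scores = Tensor::zeros((1, 1, 4, 4), DType::F16, &device)?;
    // Extended mask built in F32, as before this commit.
    let mask_f32 = Tensor::zeros((1, 1, 1, 4), DType::F32, &device)?;

    // Mixing dtypes in a broadcast op is rejected by candle.
    assert!(scores.broadcast_add(&mask_f32).is_err());

    // Casting the mask to the model dtype (what this commit does) succeeds.
    let mask_f16 = mask_f32.to_dtype(scores.dtype())?;
    let masked = scores.broadcast_add(&mask_f16)?;
    println!("{:?}", masked.dims());
    Ok(())
}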