Fix hardcoded f32 dtype for attention_mask. Use the model dtype for compatibility. (#2872)
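Why the hardcoded dtype was a problem: candle's binary tensor ops require both operands to share a dtype, so an additive attention mask built in F32 cannot be combined with attention scores from a model running in F16 or BF16. A minimal sketch of that failure mode, using candle_core only (the shapes and the broadcast_add call here are illustrative assumptions, not the model code):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in for attention scores from a model loaded in F16.
    let scores = Tensor::zeros((1, 1, 4, 4), DType::F16, &device)?;
    // The old code always built the additive mask in F32.
    let mask = Tensor::zeros((1, 1, 1, 4), DType::F32, &device)?;
    // Combining the two errors out with a dtype mismatch instead of up-casting,
    // which is why the mask now follows the model dtype.
    match scores.broadcast_add(&mask) {
        Ok(_) => println!("unexpectedly succeeded"),
        Err(e) => println!("fails as expected: {e}"),
    }
    Ok(())
}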
candle-transformers/src/models/bert.rs
@@ -504,8 +504,9 @@ impl BertModel {
             Some(attention_mask) => attention_mask.clone(),
             None => input_ids.ones_like()?,
         };
+        let dtype = embedding_output.dtype();
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
-        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
         let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
         Ok(sequence_output)
     }
@@ -519,8 +520,11 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
     };
     let attention_mask = attention_mask.to_dtype(dtype)?;
     // torch.finfo(dtype).min
-    (attention_mask.ones_like()? - &attention_mask)?
-        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
+    (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
+        &Tensor::try_from(f32::MIN)?
+            .to_device(attention_mask.device())?
+            .to_dtype(dtype)?,
+    )
 }
 
 //https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
candle-transformers/src/models/chinese_clip/text_model.rs
@@ -514,8 +514,9 @@ impl ChineseClipTextTransformer {
             Some(attention_mask) => attention_mask.clone(),
             None => input_ids.ones_like()?,
         };
+        let dtype = embedding_output.dtype();
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
-        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
         let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
         let encoder_output = encoder_outputs.i((.., 0, ..))?;
         let pooled_output = match &self.pooler {
@@ -535,6 +536,9 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
     };
     let attention_mask = attention_mask.to_dtype(dtype)?;
     // torch.finfo(dtype).min
-    (attention_mask.ones_like()? - &attention_mask)?
-        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
+    (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
+        &Tensor::try_from(f32::MIN)?
+            .to_device(attention_mask.device())?
+            .to_dtype(dtype)?,
+    )
 }
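For reference, a self-contained sketch of the dtype handling in the patched helper. The rank-expansion match above the hunk's context is omitted, and the helper name and toy mask in main are illustrative assumptions, not the library API. A 1/0 padding mask becomes 0 for real tokens and a very large negative value (f32::MIN, cast to the model dtype) for padded positions, all in whatever dtype the embeddings use:

use candle_core::{DType, Device, Result, Tensor};

// Mirrors the fixed tail of get_extended_attention_mask: 1 -> 0.0 and
// 0 -> f32::MIN, computed in the requested (model) dtype.
fn extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
    let attention_mask = attention_mask.to_dtype(dtype)?;
    (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
        &Tensor::try_from(f32::MIN)?
            .to_device(attention_mask.device())?
            .to_dtype(dtype)?,
    )
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy padding mask: three real tokens followed by one padded position.
    let mask = Tensor::new(&[1u32, 1, 1, 0], &device)?;
    // Before this commit the result was always F32; now it follows the model dtype.
    let extended = extended_attention_mask(&mask, DType::F16)?;
    println!("{extended}");
    Ok(())
}

In F16 the f32::MIN fill value overflows to negative infinity, which still suppresses the padded positions once the masked scores go through softmax.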