mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Fix the device for the bert attention mask. (#2414)
This commit is contained in:
@ -501,5 +501,6 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
||||
};
|
||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||
// torch.finfo(dtype).min
|
||||
(attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
|
||||
(attention_mask.ones_like()? - &attention_mask)?
|
||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
||||
}
|
||||
|
Reference in New Issue
Block a user