diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 42486a2d..2262aa1a 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -501,5 +501,6 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result< }; let attention_mask = attention_mask.to_dtype(dtype)?; // torch.finfo(dtype).min - (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?) + (attention_mask.ones_like()? - &attention_mask)? + .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?) }