Avoid broadcasting on the batch dimension for the attention mask. (#1920)

Laurent Mazare
2024-03-23 13:08:53 +01:00
committed by GitHub
parent cc856db9ce
commit 6f877592a7
2 changed files with 6 additions and 8 deletions
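The fix keeps the causal mask's batch dimension at size 1 and lets broadcasting replicate it across the batch downstream, instead of expanding the mask to (b_size, 1, tgt_len, tgt_len + seqlen_offset) up front. The following is a minimal sketch of that idea using candle_core's public Tensor API; the causal_mask helper is a hypothetical stand-in for the model's prepare_decoder_attention_mask, not code from this commit:

use candle_core::{Device, Result, Tensor};

// Hypothetical stand-in for prepare_decoder_attention_mask: builds a causal
// mask of shape (1, 1, tgt_len, tgt_len + seqlen_offset). The leading
// singleton dims broadcast over batch and heads later on, so no per-batch
// copy of the mask is ever materialized.
fn causal_mask(tgt_len: usize, seqlen_offset: usize, device: &Device) -> Result<Tensor> {
    let src_len = tgt_len + seqlen_offset;
    // Query position i may attend to key positions 0..=i + seqlen_offset;
    // all other positions get -inf so they vanish under softmax.
    let data: Vec<f32> = (0..tgt_len)
        .flat_map(|i| {
            (0..src_len)
                .map(move |j| if j <= i + seqlen_offset { 0f32 } else { f32::NEG_INFINITY })
        })
        .collect();
    Tensor::from_slice(&data, (tgt_len, src_len), device)?.reshape((1, 1, tgt_len, src_len))
}

fn main() -> Result<()> {
    let mask = causal_mask(3, 2, &Device::Cpu)?;
    assert_eq!(mask.dims(), &[1, 1, 3, 5]); // batch dimension stays 1
    Ok(())
}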


@@ -385,7 +385,6 @@ impl Model {
 
     fn prepare_decoder_attention_mask(
         &self,
-        b_size: usize,
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
@@ -408,16 +407,16 @@ impl Model {
         } else {
             mask
         };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+        mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
             .to_dtype(self.dtype)
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
-        let (b_size, seq_len) = input_ids.dims2()?;
+        let (_b_size, seq_len) = input_ids.dims2()?;
         let attention_mask = if seq_len <= 1 {
             None
         } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+            let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
             Some(mask)
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
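For context on why the (1, 1, tgt_len, tgt_len + seqlen_offset) shape is enough: candle's broadcast_* ops stretch size-1 dimensions on the fly, so the mask combines with per-batch, per-head attention scores without an explicit batch copy. A sketch under that assumption; the zero tensors below are placeholders for real scores and a real mask, and the shapes are illustrative:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (b_size, n_heads, tgt_len, src_len) = (4, 8, 3, 5);
    // Mask with singleton batch/head dims, as the model now produces.
    let mask = Tensor::zeros((1, 1, tgt_len, src_len), DType::F32, &device)?;
    // Attention scores carry the real batch and head dimensions.
    let scores = Tensor::zeros((b_size, n_heads, tgt_len, src_len), DType::F32, &device)?;
    // broadcast_add stretches the mask's size-1 dims lazily: the mask itself
    // stays O(tgt_len * src_len) rather than O(b_size * tgt_len * src_len).
    let masked = scores.broadcast_add(&mask)?;
    assert_eq!(masked.dims(), &[b_size, n_heads, tgt_len, src_len]);
    Ok(())
}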