mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)
Avoid broadcasting on the batch dimension for the attention mask. (#1920)
@@ -385,7 +385,6 @@ impl Model {
 
     fn prepare_decoder_attention_mask(
         &self,
-        b_size: usize,
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
@@ -408,16 +407,16 @@ impl Model {
         } else {
             mask
         };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+        mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
             .to_dtype(self.dtype)
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
-        let (b_size, seq_len) = input_ids.dims2()?;
+        let (_b_size, seq_len) = input_ids.dims2()?;
         let attention_mask = if seq_len <= 1 {
             None
         } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+            let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
             Some(mask)
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
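The change drops the b_size parameter from prepare_decoder_attention_mask and expands the mask to a leading batch dimension of 1 rather than the full batch size; the broadcasting performed by the downstream attention ops stretches that dimension as needed. The same edit is applied to a second model in this commit (this one casting the mask to DType::F32 rather than self.dtype):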
@@ -287,7 +287,6 @@ impl Model {
 
     fn prepare_decoder_attention_mask(
         &self,
-        b_size: usize,
         tgt_len: usize,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
@@ -310,16 +309,16 @@ impl Model {
         } else {
             mask
         };
-        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+        mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
             .to_dtype(DType::F32)
     }
 
     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
-        let (b_size, seq_len) = input_ids.dims2()?;
+        let (_b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
             None
         } else {
-            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+            let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
             Some(mask)
         };
         let mut xs = self.embed_tokens.forward(input_ids)?;
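For context, here is a minimal sketch (not part of the commit; candle-core API, with illustrative names and shapes) of why a size-1 batch dimension suffices: candle's broadcast_* ops stretch size-1 dimensions on the fly, so a (1, 1, tgt_len, src_len) mask can be added to (b_size, n_heads, tgt_len, src_len) attention scores without ever materializing b_size copies of it.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b_size, n_heads, tgt_len) = (2usize, 4usize, 3usize);

    // Stand-in attention scores with shape (batch, heads, tgt_len, src_len).
    let scores = Tensor::zeros((b_size, n_heads, tgt_len, tgt_len), DType::F32, &dev)?;

    // Causal mask built once with size-1 batch/head dims: 0.0 where a query
    // may attend, f32::NEG_INFINITY where it may not.
    let data: Vec<f32> = (0..tgt_len)
        .flat_map(|i| (0..tgt_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0.0 }))
        .collect();
    let mask = Tensor::from_vec(data, (tgt_len, tgt_len), &dev)?
        .reshape((1, 1, tgt_len, tgt_len))?;

    // broadcast_add stretches the size-1 dims of the mask to match `scores`,
    // so no batch-sized copy of the mask is materialized.
    let masked = scores.broadcast_add(&mask)?;
    let (b, h, t, s) = masked.dims4()?;
    assert_eq!((b, h, t, s), (b_size, n_heads, tgt_len, tgt_len));
    Ok(())
}

This mirrors how the models apply the mask during attention, where broadcast ops handle the batch and head dimensions, so expanding the mask to the full batch size up front only added work.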