diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 717e6771..22ddfd2f 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -181,7 +181,12 @@ impl LayerWeights {
         Ok(rope)
     }
 
-    fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+    fn forward_attn(
+        &mut self,
+        x: &Tensor,
+        mask: Option<&Tensor>,
+        index_pos: usize,
+    ) -> Result<Tensor> {
         let _enter = self.span_attn.enter();
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.attention_wq.forward(x)?;
@@ -220,8 +225,13 @@ impl LayerWeights {
         let v = self.repeat_kv(v)?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = mask.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, &self.neg_inf)?;
+        let att = match mask {
+            None => att,
+            Some(mask) => {
+                let mask = mask.broadcast_as(att.shape())?;
+                masked_fill(&att, &mask, &self.neg_inf)?
+            }
+        };
         let att = candle_nn::ops::softmax_last_dim(&att)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
@@ -474,14 +484,18 @@ impl ModelWeights {
 
     pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
-        let mask = self.mask(seq_len, x.device())?;
+        let mask = if seq_len == 1 {
+            None
+        } else {
+            Some(self.mask(seq_len, x.device())?)
+        };
         let _enter = self.span.enter();
         let mut layer_in = self.tok_embeddings.forward(x)?;
         for layer in self.layers.iter_mut() {
             let x = layer_in;
             let residual = &x;
             let x = layer.attention_norm.forward(&x)?;
-            let attn = layer.forward_attn(&x, &mask, index_pos)?;
+            let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
             let x = (attn + residual)?;
 
             // MLP
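
For context, here is a rough driver sketch (illustrative only, not part of this patch) showing when each branch of the new code is hit: the prompt pass sends the whole token sequence through `forward`, so `seq_len > 1` and the causal mask is built, while every subsequent decode step feeds a single token, so `seq_len == 1`, the mask is `None`, and `forward_attn` skips `broadcast_as` + `masked_fill`. Model loading is elided; `generate` and `argmax_token` are hypothetical helper names, greedy argmax stands in for a real sampler, and the `(1, vocab_size)` logits shape is assumed from the model's last-position output.

```rust
// Hypothetical generation loop sketching how the optional mask is exercised.
// Assumptions: the crate is used as `candle_core`, `ModelWeights` has already
// been loaded from a GGUF file, and logits come back as (1, vocab_size).
use candle_core::{D, Device, Result, Tensor};
use candle_transformers::models::quantized_llama::ModelWeights;

// Greedy pick of the highest-scoring token id from (1, vocab_size) logits.
fn argmax_token(logits: &Tensor) -> Result<u32> {
    logits.argmax(D::Minus1)?.squeeze(0)?.to_scalar::<u32>()
}

fn generate(
    model: &mut ModelWeights,
    prompt: &[u32],
    device: &Device,
    steps: usize,
) -> Result<Vec<u32>> {
    let mut out = Vec::with_capacity(steps);

    // Prompt pass: seq_len > 1, so `forward` builds the causal mask and
    // `forward_attn` receives `Some(mask)`.
    let input = Tensor::new(prompt, device)?.unsqueeze(0)?;
    let logits = model.forward(&input, 0)?;
    let mut next = argmax_token(&logits)?;
    out.push(next);

    // Incremental decoding: one token per call, so seq_len == 1, the mask is
    // skipped, and attention runs without the masked_fill.
    for step in 1..steps {
        let input = Tensor::new(&[next], device)?.unsqueeze(0)?;
        let logits = model.forward(&input, prompt.len() + step - 1)?;
        next = argmax_token(&logits)?;
        out.push(next);
    }
    Ok(out)
}
```

The single-token path is the common case during autoregressive generation, so dropping the broadcast and fill there avoids per-step work that the KV cache already makes unnecessary: with one query position attending to all cached keys, nothing needs to be masked out.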