From cf7d7fcf2f20c24aae633483c3a107c1219a7f9a Mon Sep 17 00:00:00 2001
From: laurent
Date: Sun, 24 Mar 2024 19:04:32 +0100
Subject: [PATCH] Also avoid the mask in the llama example.

---
 candle-transformers/src/models/llama.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index c311d4c4..73671cdc 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -240,8 +240,12 @@ impl CausalSelfAttention {
         let k = k.to_dtype(DType::F32)?;
         let v = v.to_dtype(DType::F32)?;
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = if seq_len == 1 {
+            att
+        } else {
+            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
+            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+        };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
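
Note (not part of the patch): the mask can be skipped when seq_len == 1 because the
causal mask only flags positions strictly after the current query position, so for a
single query token it contains only zeros and masked_fill would be a no-op; this is the
common case during incremental decoding with the KV cache. A minimal standalone sketch
of that reasoning, using plain Rust vectors rather than candle tensors; the causal_mask
helper below is hypothetical and only mirrors the assumed shape of Cache::mask.

    // Hypothetical helper: a seq_len x seq_len matrix where entry (i, j) is 1 when
    // j > i (a future position that must be masked out) and 0 otherwise.
    fn causal_mask(seq_len: usize) -> Vec<Vec<u8>> {
        (0..seq_len)
            .map(|i| (0..seq_len).map(|j| u8::from(j > i)).collect())
            .collect()
    }

    fn main() {
        // With a KV cache only one new token is fed per step, so seq_len == 1 and the
        // mask is a single zero: applying it changes nothing, hence the early return.
        assert_eq!(causal_mask(1), vec![vec![0]]);

        // For a prompt of length 3 the mask has ones above the diagonal; those are the
        // entries that masked_fill sets to f32::NEG_INFINITY before the softmax.
        assert_eq!(
            causal_mask(3),
            vec![vec![0, 1, 1], vec![0, 0, 1], vec![0, 0, 0]]
        );
    }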