From 7582937a32d3648bf65a170768e4734758a68a93 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 23 Sep 2023 09:50:26 +0100
Subject: [PATCH] Add the causal mask in mixformer. (#937)

---
 candle-transformers/src/models/mixformer.rs | 38 ++++++++++++++++++---
 1 file changed, 33 insertions(+), 5 deletions(-)

diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 61eaea54..6a3b5515 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -75,6 +75,20 @@ impl Module for Embedding {
     }
 }
 
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
 #[derive(Debug)]
 struct RotaryEmbedding {
     sin: Tensor,
@@ -198,6 +212,7 @@ struct MHA {
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
+    n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
 }
@@ -214,6 +229,7 @@ impl MHA {
             wqkv,
             out_proj,
             head_dim,
+            n_head: cfg.n_head,
             kv_cache: None,
             rotary_emb,
             softmax_scale,
@@ -221,7 +237,7 @@ impl MHA {
         })
     }
 
-    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self
@@ -249,9 +265,16 @@ impl MHA {
         let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
         let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
 
-        // TODO: Add the causal mask.
         // causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
         // scores = scores + causal_mask.to(dtype=scores.dtype)
+        let attn_weights = match mask {
+            None => attn_weights,
+            Some(mask) => masked_fill(
+                &attn_weights,
+                &mask.broadcast_left(b_size * self.n_head)?,
+                f32::NEG_INFINITY,
+            )?,
+        };
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
 
         // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
@@ -287,11 +310,11 @@ impl ParallelBlock {
         })
     }
 
-    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let residual = xs;
         let xs = xs.apply(&self.ln)?;
-        let attn_outputs = self.mixer.forward(&xs)?;
+        let attn_outputs = self.mixer.forward(&xs, mask)?;
         let feed_forward_hidden_states = self.mlp.forward(&xs)?;
         attn_outputs + feed_forward_hidden_states + residual
     }
@@ -327,8 +350,13 @@ impl MixFormerSequentialForCausalLM {
         let _enter = self.span.enter();
         let (_b_size, seq_len) = xs.dims2()?;
         let mut xs = xs.apply(&self.embedding)?;
+        let mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(get_mask(seq_len, xs.device())?)
+        };
         for block in self.blocks.iter_mut() {
-            xs = block.forward(&xs)?
+            xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }
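
As a quick illustration of what this change does (not part of the patch itself): the snippet below mirrors `get_mask` and `masked_fill`, building the upper-triangular causal mask and pushing attention to future positions to `f32::NEG_INFINITY` before the softmax. It is a minimal standalone sketch assuming the published `candle-core` and `candle-nn` crates as dependencies; the helpers are re-inlined here since they are private to `mixformer.rs`, and `main`, `scores`, and `dev` are illustrative names only.

    use candle_core::{DType, Device, Result, Tensor, D};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let size = 4usize;

        // Same construction as `get_mask`: 1 marks the future positions (j > i)
        // that the token at position i must not attend to.
        let mask: Vec<u8> = (0..size)
            .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
            .collect();
        let mask = Tensor::from_slice(&mask, (size, size), &dev)?;

        // Same idea as `masked_fill`: masked scores become -inf, so the softmax
        // assigns them zero weight and each position only attends to itself and
        // to earlier positions.
        let scores = Tensor::ones((size, size), DType::F32, &dev)?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &dev)?.broadcast_as((size, size))?;
        let masked = mask.where_cond(&neg_inf, &scores)?;
        let weights = candle_nn::ops::softmax(&masked, D::Minus1)?;

        // With uniform scores, row 0 attends only to position 0 ([1, 0, 0, 0]),
        // while the last row attends evenly to all four positions.
        println!("{weights}");
        Ok(())
    }

In the patch itself the mask is built once per forward pass, skipped when seq_len <= 1 (decoding a single token against the kv cache), and broadcast over batch and heads with `broadcast_left(b_size * self.n_head)` so it lines up with the flattened (b*h, t, s) attention scores.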