Add the causal mask in mixformer. (#937)

Laurent Mazare
2023-09-23 09:50:26 +01:00
committed by GitHub
parent b54acfa3d0
commit 7582937a32


@@ -75,6 +75,20 @@ impl Module for Embedding {
     }
 }
 
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    Tensor::from_slice(&mask, (size, size), device)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
 #[derive(Debug)]
 struct RotaryEmbedding {
     sin: Tensor,
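For illustration, get_mask builds an upper-triangular matrix whose 1s mark the future positions a query may not attend to. A minimal sketch of how it behaves, assuming candle_core as a dependency (the main driver and the size of 4 are illustrative, not part of this commit):

use candle_core::{Device, Result, Tensor};

// Same helper as in the diff above: 1 where the key index j lies strictly
// after the query index i (a future position), 0 elsewhere.
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

fn main() -> Result<()> {
    let mask = get_mask(4, &Device::Cpu)?;
    // Expected rows: [0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]
    println!("{:?}", mask.to_vec2::<u8>()?);
    Ok(())
}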
@@ -198,6 +212,7 @@ struct MHA {
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
+    n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
 }
@@ -214,6 +229,7 @@ impl MHA {
             wqkv,
             out_proj,
             head_dim,
+            n_head: cfg.n_head,
             kv_cache: None,
             rotary_emb,
             softmax_scale,
@@ -221,7 +237,7 @@ impl MHA {
         })
     }
 
-    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self
@@ -249,9 +265,16 @@ impl MHA {
         let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
 
         let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
-        // TODO: Add the causal mask.
         // causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
         // scores = scores + causal_mask.to(dtype=scores.dtype)
+        let attn_weights = match mask {
+            None => attn_weights,
+            Some(mask) => masked_fill(
+                &attn_weights,
+                &mask.broadcast_left(b_size * self.n_head)?,
+                f32::NEG_INFINITY,
+            )?,
+        };
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
 
         // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
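masked_fill writes f32::NEG_INFINITY into every position where the mask is 1, so after the softmax those positions receive exactly zero attention weight. A minimal sketch of that pattern, assuming candle_core and candle_nn as dependencies (the 2-token scores and the main driver are illustrative, not from this commit):

use candle_core::{D, Device, Result, Tensor};

// Same helper as in the diff above.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Raw attention scores for a 2-token sequence (rows = queries, cols = keys).
    let scores = Tensor::new(&[[0.5f32, 2.0], [1.0, 1.0]], &dev)?;
    // Causal mask: 1 marks a future key that the query must not see.
    let mask = Tensor::new(&[[0u8, 1], [0, 0]], &dev)?;
    let masked = masked_fill(&scores, &mask, f32::NEG_INFINITY)?;
    // exp(-inf) = 0, so the first query puts all of its weight on the first
    // key, while the second query attends to both keys.
    let probs = candle_nn::ops::softmax(&masked, D::Minus1)?;
    println!("{:?}", probs.to_vec2::<f32>()?);
    Ok(())
}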
@@ -287,11 +310,11 @@ impl ParallelBlock {
         })
     }
 
-    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let residual = xs;
         let xs = xs.apply(&self.ln)?;
-        let attn_outputs = self.mixer.forward(&xs)?;
+        let attn_outputs = self.mixer.forward(&xs, mask)?;
         let feed_forward_hidden_states = self.mlp.forward(&xs)?;
         attn_outputs + feed_forward_hidden_states + residual
     }
@@ -327,8 +350,13 @@ impl MixFormerSequentialForCausalLM {
         let _enter = self.span.enter();
         let (_b_size, seq_len) = xs.dims2()?;
         let mut xs = xs.apply(&self.embedding)?;
+        let mask = if seq_len <= 1 {
+            None
+        } else {
+            Some(get_mask(seq_len, xs.device())?)
+        };
         for block in self.blocks.iter_mut() {
-            xs = block.forward(&xs)?
+            xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
     }
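For context on the call site above: a mask is only built when seq_len > 1 (the prompt pass), and it is then broadcast on the left to the (b*h, t, s) layout of the attention scores; during incremental decoding with the kv-cache each step processes a single new token, and that one query may attend to every cached key, so no mask is needed. A small sketch under those assumptions, with made-up batch size, head count, and sequence lengths:

use candle_core::{Device, Result, Tensor};

// Same helper as in the diff above.
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b_size, n_head): (usize, usize) = (2, 4);

    // Prompt pass: several tokens at once, so a causal mask is built and then
    // broadcast on the left to match the (b*h, t, s) shape of attn_weights.
    let seq_len: usize = 5;
    let mask = if seq_len <= 1 { None } else { Some(get_mask(seq_len, &dev)?) };
    let mask = mask.expect("the prompt pass builds a mask");
    assert_eq!(mask.broadcast_left(b_size * n_head)?.dims(), &[8usize, 5, 5]);

    // Incremental decoding with the kv-cache: one new token per step; that
    // single query may attend to every cached key, so no mask is built.
    let seq_len: usize = 1;
    let mask = if seq_len <= 1 { None } else { Some(get_mask(seq_len, &dev)?) };
    assert!(mask.is_none());
    Ok(())
}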