Add the causal mask in text decoding.

This commit is contained in:
laurent
2023-07-04 15:25:47 +01:00
parent 04f4ef81e8
commit 31663bc04f

View File

@ -316,11 +316,11 @@ impl MultiHeadAttention {
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let q = self.query.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
let wv = self.qkv_attention(&q, &k, &v)?;
let wv = self.qkv_attention(&q, &k, &v, mask)?;
let out = self.out.forward(&wv)?;
Ok(out)
}
@ -331,13 +331,23 @@ impl MultiHeadAttention {
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
}
fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let (_, _, n_state) = q.shape().r3()?;
fn qkv_attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
mask: Option<&Tensor>,
) -> Result<Tensor> {
let (_, n_ctx, n_state) = q.shape().r3()?;
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
let q = (self.reshape_head(q)? * scale)?;
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
let v = self.reshape_head(v)?.contiguous()?;
let qk = q.matmul(&k)?;
let mut qk = q.matmul(&k)?;
if let Some(mask) = mask {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
let w = qk.softmax(qk.rank() - 1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
Ok(wv)
@ -380,11 +390,11 @@ impl ResidualAttentionBlock {
})
}
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
let mut x = (x + attn)?;
if let Some((attn, ln)) = &self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
}
let mlp = self.mlp_linear2.forward(
&self
@ -456,7 +466,7 @@ impl AudioEncoder {
let x = x.transpose(1, 2)?;
let mut x = x.broadcast_add(&self.positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, None)?
x = block.forward(&x, None, None)?
}
let x = self.ln_post.forward(&x)?;
Ok(x)
@ -469,6 +479,7 @@ struct TextDecoder {
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
}
impl TextDecoder {
@ -486,11 +497,17 @@ impl TextDecoder {
})
.collect::<Result<Vec<_>>>()?;
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), &vb.device)?;
Ok(Self {
token_embedding,
positional_embedding,
blocks,
ln,
mask,
})
}
fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
@ -500,7 +517,7 @@ impl TextDecoder {
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, Some(xa))?;
x = block.forward(&x, Some(xa), Some(&self.mask))?;
}
let x = self.ln.forward(&x)?;
let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;