mirror of https://github.com/huggingface/candle.git
Add the causal mask in text decoding.
@@ -316,11 +316,11 @@ impl MultiHeadAttention {
         })
     }

-    fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
+    fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
         let q = self.query.forward(x)?;
         let k = self.key.forward(xa.unwrap_or(x))?;
         let v = self.value.forward(xa.unwrap_or(x))?;
-        let wv = self.qkv_attention(&q, &k, &v)?;
+        let wv = self.qkv_attention(&q, &k, &v, mask)?;
         let out = self.out.forward(&wv)?;
         Ok(out)
     }
@@ -331,13 +331,23 @@ impl MultiHeadAttention {
         Ok(x.reshape(target_dims)?.transpose(1, 2)?)
     }

-    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-        let (_, _, n_state) = q.shape().r3()?;
+    fn qkv_attention(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        v: &Tensor,
+        mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let (_, n_ctx, n_state) = q.shape().r3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
         let q = (self.reshape_head(q)? * scale)?;
         let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
         let v = self.reshape_head(v)?.contiguous()?;
-        let qk = q.matmul(&k)?;
+        let mut qk = q.matmul(&k)?;
+        if let Some(mask) = mask {
+            let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
+            qk = qk.broadcast_add(&mask)?
+        }
         let w = qk.softmax(qk.rank() - 1)?;
         let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
         Ok(wv)
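For reference, here is a minimal plain-Rust sketch of what the masking step in qkv_attention does to the attention scores; it does not use the candle Tensor API, and the helper names (apply_causal_mask, softmax_row) are illustrative only. Positions j > i get negative infinity added before the softmax, so each token can only attend to itself and to earlier positions.

// Illustrative sketch only: the real code above works on candle Tensors.
// `scores` is an (n_ctx, n_ctx) row-major matrix of raw attention logits.
fn apply_causal_mask(scores: &mut [f32], n_ctx: usize) {
    for i in 0..n_ctx {
        for j in (i + 1)..n_ctx {
            // Mirrors `qk.broadcast_add(&mask)`: future positions receive -inf.
            scores[i * n_ctx + j] += f32::NEG_INFINITY;
        }
    }
}

fn softmax_row(row: &mut [f32]) {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0f32;
    for x in row.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    for x in row.iter_mut() {
        *x /= sum;
    }
}

fn main() {
    let n_ctx = 3;
    // All-zero logits: without the mask every row would softmax to [1/3, 1/3, 1/3].
    let mut scores = vec![0f32; n_ctx * n_ctx];
    apply_causal_mask(&mut scores, n_ctx);
    for row in scores.chunks_mut(n_ctx) {
        softmax_row(row);
    }
    // Row 0 now attends only to position 0, row 1 to positions 0 and 1, and so on.
    println!("{scores:?}");
}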
@@ -380,11 +390,11 @@ impl ResidualAttentionBlock {
         })
     }

-    fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
-        let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
+    fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+        let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
         let mut x = (x + attn)?;
         if let Some((attn, ln)) = &self.cross_attn {
-            x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
+            x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
         }
         let mlp = self.mlp_linear2.forward(
             &self
@@ -456,7 +466,7 @@ impl AudioEncoder {
         let x = x.transpose(1, 2)?;
         let mut x = x.broadcast_add(&self.positional_embedding)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x, None)?
+            x = block.forward(&x, None, None)?
         }
         let x = self.ln_post.forward(&x)?;
         Ok(x)
@@ -469,6 +479,7 @@ struct TextDecoder {
     positional_embedding: Tensor,
     blocks: Vec<ResidualAttentionBlock>,
     ln: LayerNorm,
+    mask: Tensor,
 }

 impl TextDecoder {
@@ -486,11 +497,17 @@ impl TextDecoder {
             })
             .collect::<Result<Vec<_>>>()?;
         let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
+        let mask: Vec<_> = (0..n_ctx)
+            .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+            .collect();
+        let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), &vb.device)?;
+
         Ok(Self {
             token_embedding,
             positional_embedding,
             blocks,
             ln,
+            mask,
         })
     }
     fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
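The mask built in TextDecoder::load above is precomputed once for the full n_ctx and stored on the decoder; qkv_attention then narrows it to the current sequence length. As a standalone illustration using a plain Vec<f32> instead of a candle Tensor (the causal_mask helper below is hypothetical), the layout is zeros on and below the diagonal and negative infinity above it:

// Same construction as in TextDecoder::load above, minus Tensor::from_vec.
fn causal_mask(n_ctx: usize) -> Vec<f32> {
    (0..n_ctx)
        .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
        .collect()
}

fn main() {
    // For n_ctx = 4 the rows are:
    //   [0, -inf, -inf, -inf]
    //   [0,    0, -inf, -inf]
    //   [0,    0,    0, -inf]
    //   [0,    0,    0,    0]
    let mask = causal_mask(4);
    for row in mask.chunks(4) {
        println!("{row:?}");
    }
}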
@@ -500,7 +517,7 @@ impl TextDecoder {
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
         let mut x = token_embedding.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x, Some(xa))?;
+            x = block.forward(&x, Some(xa), Some(&self.mask))?;
         }
         let x = self.ln.forward(&x)?;
         let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;