From 0d99b4379224838265bf9bedd836dd655b94a001 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 4 Jul 2023 12:27:05 +0100
Subject: [PATCH] Line up the textdecoder values with the python implementation.

---
 candle-examples/examples/whisper/main.rs | 36 ++++++++++++------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c3a2769f..6341c5ee 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -323,13 +323,18 @@ impl MultiHeadAttention {
         Ok(out)
     }
 
-    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-        let (n_batch, n_ctx, n_state) = q.shape().r3()?;
+    fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
+        let (n_batch, n_ctx, n_state) = x.shape().r3()?;
         let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
+        Ok(x.reshape(target_dims)?.transpose(1, 2)?)
+    }
+
+    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+        let (_, _, n_state) = q.shape().r3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
-        let q = (q.reshape(target_dims)?.transpose(1, 2)? * scale)?;
-        let k = (k.reshape(target_dims)?.transpose(1, 2)?.transpose(2, 3)? * scale)?;
-        let v = v.reshape(target_dims)?.transpose(1, 2)?;
+        let q = (self.reshape_head(q)? * scale)?;
+        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
+        let v = self.reshape_head(v)?.contiguous()?;
         let qk = q.matmul(&k)?;
         let w = qk.softmax(qk.rank() - 1)?;
         let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
@@ -341,8 +346,7 @@ impl MultiHeadAttention {
 struct ResidualAttentionBlock {
     attn: MultiHeadAttention,
     attn_ln: LayerNorm,
-    cross_attn: Option<MultiHeadAttention>,
-    cross_attn_ln: Option<LayerNorm>,
+    cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
     mlp_linear1: Linear,
     mlp_linear2: Linear,
     mlp_ln: LayerNorm,
@@ -352,13 +356,13 @@ impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
         let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
         let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
-        let (cross_attn, cross_attn_ln) = if ca {
+        let cross_attn = if ca {
             let cross_attn =
                 MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
             let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
-            (Some(cross_attn), Some(cross_attn_ln))
+            Some((cross_attn, cross_attn_ln))
         } else {
-            (None, None)
+            None
         };
         let n_mlp = n_state * 4;
         let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
@@ -368,7 +372,6 @@ impl ResidualAttentionBlock {
             attn,
             attn_ln,
             cross_attn,
-            cross_attn_ln,
             mlp_linear1,
             mlp_linear2,
             mlp_ln,
@@ -379,11 +382,8 @@
         let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
         let mut x = (x + attn)?;
         // Cross-Attn
-        if let Some(cross_attn_ln) = &self.cross_attn_ln {
-            x = cross_attn_ln.forward(&x)?
-        }
-        if let Some(cross_attn) = &self.cross_attn {
-            x = cross_attn.forward(&x, xa)?
+        if let Some((attn, ln)) = &self.cross_attn {
+            x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
         }
         // Mlp
         let mlp = self.mlp_linear2.forward(
@@ -482,7 +482,7 @@ impl TextDecoder {
             vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
         let blocks = (0..cfg.n_text_layer)
             .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
+                ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
         let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
@@ -502,7 +502,7 @@ impl TextDecoder {
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
         let mut x = token_embedding.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x, Some(xa))?
+            x = block.forward(&x, Some(xa))?;
         }
         let x = self.ln.forward(&x)?;
         let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;