From c09aa4b0f4c1fb074aecbf71aea3e96cadebea60 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 4 Jul 2023 09:14:28 +0100
Subject: [PATCH] Add more to the forward pass and fix the cross-attention.

---
 candle-examples/examples/whisper/main.rs | 37 +++++++++++++++++-------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 1930a7a9..129c73b8 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -225,10 +225,10 @@ impl MultiHeadAttention {
         })
     }
 
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
         let q = self.query.forward(x)?;
-        let k = self.key.forward(x)?;
-        let v = self.value.forward(x)?;
+        let k = self.key.forward(xa.unwrap_or(x))?;
+        let v = self.value.forward(xa.unwrap_or(x))?;
         let wv = self.qkv_attention(&q, &k, &v)?;
         let out = self.out.forward(&wv)?;
         Ok(out)
@@ -286,16 +286,16 @@ impl ResidualAttentionBlock {
         })
     }
 
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let attn = self.attn.forward(&self.attn_ln.forward(x)?)?;
+    fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
+        let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
         let mut x = (x + attn)?;
         // Cross-Attn
-        if let Some(cross_attn) = &self.cross_attn {
-            x = cross_attn.forward(&x)?
-        }
         if let Some(cross_attn_ln) = &self.cross_attn_ln {
             x = cross_attn_ln.forward(&x)?
         }
+        if let Some(cross_attn) = &self.cross_attn {
+            x = cross_attn.forward(&x, xa)?
+        }
         // Mlp
         let mlp = self.mlp_linear2.forward(
             &self
@@ -358,7 +358,7 @@ impl AudioEncoder {
         let x = x.transpose(1, 2)?;
         let mut x = x.broadcast_add(&self.positional_embedding)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x)?
+            x = block.forward(&x, None)?
         }
         let x = self.ln_post.forward(&x)?;
         Ok(x)
@@ -368,6 +368,7 @@ impl AudioEncoder {
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
 struct TextDecoder {
     token_embedding: Embedding,
+    positional_embedding: Tensor,
     blocks: Vec<ResidualAttentionBlock>,
     ln: LayerNorm,
     mask: Tensor,
@@ -377,8 +378,11 @@ impl TextDecoder {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let n_state = cfg.n_text_state;
         let n_head = cfg.n_text_head;
+        let n_ctx = cfg.n_text_ctx;
         let token_embedding =
             Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
+        let positional_embedding =
+            vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
         let blocks = (0..cfg.n_text_layer)
             .map(|i| {
                 ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
         let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
         let mask = Tensor::new(&[0u32], &vb.device)?; // TODO
         Ok(Self {
             token_embedding,
+            positional_embedding,
             blocks,
             ln,
             mask,
         })
     }
 
-    fn forward(&self, _tokens: &Tensor, _enc: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+        let x_dims = x.dims();
+        let last = x_dims[x_dims.len() - 1];
+        let token_embedding = self.token_embedding.forward(x)?;
+        let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
+        let mut x = (token_embedding + positional_embedding)?;
+        for block in self.blocks.iter() {
+            x = block.forward(&x, Some(xa))?
+        }
+        let x = self.ln.forward(&x)?;
+        let logits = x.matmul(&self.token_embedding.embeddings.t()?)?;
+        Ok(logits)
     }
 }
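
Not part of the patch itself: the short sketch below is a minimal, framework-free illustration of the cross-attention fix above — queries always come from the current decoder stream, while `xa.unwrap_or(x)` switches keys and values to the encoder output when one is passed. It uses plain std Rust slices instead of candle Tensors, and the function names are made up for illustration only.

// Standalone sketch (not from the patch): mirrors the source selection done by
// `xa.unwrap_or(x)` in MultiHeadAttention::forward, using hypothetical helpers
// over plain slices rather than candle Tensors.

/// Returns (q_src, k_src, v_src): queries from the current stream `x`;
/// keys/values from the encoder output `xa` if given (cross-attention),
/// otherwise from `x` itself (self-attention).
fn attention_sources<'a>(
    x: &'a [f32],
    xa: Option<&'a [f32]>,
) -> (&'a [f32], &'a [f32], &'a [f32]) {
    let kv = xa.unwrap_or(x);
    (x, kv, kv)
}

fn main() {
    let decoder_state = vec![0.1_f32, 0.2, 0.3];
    let encoder_output = vec![0.9_f32, 0.8, 0.7, 0.6];

    // Self-attention: no encoder state passed, so q, k and v share one source.
    let (q, k, v) = attention_sources(decoder_state.as_slice(), None);
    assert!(std::ptr::eq(q, k) && std::ptr::eq(k, v));

    // Cross-attention: keys/values switch to the encoder output, queries do not.
    let (q, k, v) = attention_sources(decoder_state.as_slice(), Some(encoder_output.as_slice()));
    assert!(std::ptr::eq(k, v) && !std::ptr::eq(q, k));
    println!("q len = {}, k/v len = {}", q.len(), k.len());

    // Rough analogue of `positional_embedding.narrow(0, 0, last)` in the new
    // TextDecoder::forward: keep only the first `seq_len` positional entries.
    let positional: Vec<f32> = (0..10).map(|i| i as f32).collect();
    let seq_len = decoder_state.len();
    let pos_for_tokens = &positional[..seq_len];
    assert_eq!(pos_for_tokens.len(), seq_len);
}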