Add more to the forward pass and fix the cross-attention.

This commit is contained in:
laurent
2023-07-04 09:14:28 +01:00
parent 998cfda9c4
commit c09aa4b0f4

View File

@ -225,10 +225,10 @@ impl MultiHeadAttention {
})
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
let q = self.query.forward(x)?;
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
let k = self.key.forward(xa.unwrap_or(x))?;
let v = self.value.forward(xa.unwrap_or(x))?;
let wv = self.qkv_attention(&q, &k, &v)?;
let out = self.out.forward(&wv)?;
Ok(out)
@ -286,16 +286,16 @@ impl ResidualAttentionBlock {
})
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let attn = self.attn.forward(&self.attn_ln.forward(x)?)?;
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
let mut x = (x + attn)?;
// Cross-Attn
if let Some(cross_attn) = &self.cross_attn {
x = cross_attn.forward(&x)?
}
if let Some(cross_attn_ln) = &self.cross_attn_ln {
x = cross_attn_ln.forward(&x)?
}
if let Some(cross_attn) = &self.cross_attn {
x = cross_attn.forward(&x, xa)?
}
// Mlp
let mlp = self.mlp_linear2.forward(
&self
@ -358,7 +358,7 @@ impl AudioEncoder {
let x = x.transpose(1, 2)?;
let mut x = x.broadcast_add(&self.positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x)?
x = block.forward(&x, None)?
}
let x = self.ln_post.forward(&x)?;
Ok(x)
@ -368,6 +368,7 @@ impl AudioEncoder {
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
struct TextDecoder {
token_embedding: Embedding,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm,
mask: Tensor,
@ -377,8 +378,11 @@ impl TextDecoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.n_text_state;
let n_head = cfg.n_text_head;
let n_ctx = cfg.n_text_ctx;
let token_embedding =
Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
let positional_embedding =
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
let blocks = (0..cfg.n_text_layer)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
@ -388,13 +392,24 @@ impl TextDecoder {
let mask = Tensor::new(&[0u32], &vb.device)?; // TODO
Ok(Self {
token_embedding,
positional_embedding,
blocks,
ln,
mask,
})
}
fn forward(&self, _tokens: &Tensor, _enc: &Tensor) -> Result<Tensor> {
todo!()
fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
let x_dims = x.dims();
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = (token_embedding + positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, Some(xa))?
}
let x = self.ln.forward(&x)?;
let logits = x.matmul(&self.token_embedding.embeddings.t()?)?;
Ok(logits)
}
}