Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Add more to the forward pass and fix the cross-attention.
@@ -225,10 +225,10 @@ impl MultiHeadAttention {
         })
     }
 
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
         let q = self.query.forward(x)?;
-        let k = self.key.forward(x)?;
-        let v = self.value.forward(x)?;
+        let k = self.key.forward(xa.unwrap_or(x))?;
+        let v = self.value.forward(xa.unwrap_or(x))?;
         let wv = self.qkv_attention(&q, &k, &v)?;
         let out = self.out.forward(&wv)?;
         Ok(out)
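The `xa: Option<&Tensor>` argument is what lets the same module serve as either self-attention or cross-attention: queries always come from `x`, while keys and values come from `xa` (the audio encoder output) when it is provided and fall back to `x` otherwise. A minimal sketch of that selection rule with plain Rust stand-ins rather than candle types (`Src` and `kv_source` are illustrative names, not part of the crate):

// `Src` stands in for a Tensor; the point is only the Option fallback.
#[derive(Debug)]
struct Src(&'static str);

fn kv_source<'a>(x: &'a Src, xa: Option<&'a Src>) -> &'a Src {
    // Same rule as `self.key.forward(xa.unwrap_or(x))` in the diff.
    xa.unwrap_or(x)
}

fn main() {
    let x = Src("decoder hidden states");
    let xa = Src("audio encoder output");
    println!("{:?}", kv_source(&x, None));      // self-attention: k/v from x
    println!("{:?}", kv_source(&x, Some(&xa))); // cross-attention: k/v from xa
}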
@@ -286,16 +286,16 @@ impl ResidualAttentionBlock {
         })
     }
 
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let attn = self.attn.forward(&self.attn_ln.forward(x)?)?;
+    fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
+        let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
         let mut x = (x + attn)?;
         // Cross-Attn
-        if let Some(cross_attn) = &self.cross_attn {
-            x = cross_attn.forward(&x)?
-        }
         if let Some(cross_attn_ln) = &self.cross_attn_ln {
             x = cross_attn_ln.forward(&x)?
         }
+        if let Some(cross_attn) = &self.cross_attn {
+            x = cross_attn.forward(&x, xa)?
+        }
         // Mlp
         let mlp = self.mlp_linear2.forward(
             &self
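Two things change in this block: the forward pass now threads the encoder output `xa` through to the cross-attention layer, and the cross-attention layer norm is applied before the cross-attention call rather than after it. A rough control-flow sketch with scalar stand-ins for the tensors (the arithmetic is a placeholder, not real attention; only the branching mirrors the diff):

// Stand-in block: `has_cross_attn` plays the role of the Option fields.
struct Block {
    has_cross_attn: bool,
}

impl Block {
    fn forward(&self, x: f32, xa: Option<f32>) -> f32 {
        // Self-attention + residual add, as in `let mut x = (x + attn)?`.
        let mut x = x + 0.1 * x;
        // Only decoder blocks carry the cross-attention layers, and only this
        // branch ever looks at the encoder output `xa`.
        if self.has_cross_attn {
            let xa = xa.expect("decoder blocks are called with Some(xa)");
            // Stand-in for `cross_attn_ln` followed by `cross_attn.forward(&x, xa)`.
            x = 0.5 * (x + xa);
        }
        // Mlp (placeholder arithmetic), as in the `// Mlp` section of the diff.
        x + 0.1 * x
    }
}

fn main() {
    let encoder_block = Block { has_cross_attn: false };
    let decoder_block = Block { has_cross_attn: true };
    println!("{}", encoder_block.forward(1.0, None));
    println!("{}", decoder_block.forward(1.0, Some(2.0)));
}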
@@ -358,7 +358,7 @@ impl AudioEncoder {
         let x = x.transpose(1, 2)?;
         let mut x = x.broadcast_add(&self.positional_embedding)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x)?
+            x = block.forward(&x, None)?
         }
         let x = self.ln_post.forward(&x)?;
         Ok(x)
@@ -368,6 +368,7 @@ impl AudioEncoder {
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
 struct TextDecoder {
     token_embedding: Embedding,
     positional_embedding: Tensor,
     blocks: Vec<ResidualAttentionBlock>,
     ln: LayerNorm,
+    mask: Tensor,
@@ -377,8 +378,11 @@ impl TextDecoder {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let n_state = cfg.n_text_state;
         let n_head = cfg.n_text_head;
+        let n_ctx = cfg.n_text_ctx;
         let token_embedding =
             Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
+        let positional_embedding =
+            vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
         let blocks = (0..cfg.n_text_layer)
             .map(|i| {
                 ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
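`TextDecoder::load` now also pulls in the learned positional embedding table of shape `(n_text_ctx, n_text_state)` next to the `(n_vocab, n_text_state)` token embedding. A tiny shape-bookkeeping sketch (the concrete numbers are illustrative stand-ins, not read from `Config` here):

fn main() {
    // Illustrative values only; the real ones come from Config fields such as n_text_ctx.
    let (n_vocab, n_ctx, n_state) = (51_865usize, 448usize, 384usize);
    let token_embedding_shape = (n_vocab, n_state); // one row per token id
    let positional_embedding_shape = (n_ctx, n_state); // one row per position
    // Both tables share the n_state dimension, which is what lets the forward
    // pass add a token row and a position row element-wise.
    assert_eq!(token_embedding_shape.1, positional_embedding_shape.1);
    println!("token: {token_embedding_shape:?}, positional: {positional_embedding_shape:?}");
}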
@@ -388,13 +392,24 @@ impl TextDecoder {
         let mask = Tensor::new(&[0u32], &vb.device)?; // TODO
         Ok(Self {
             token_embedding,
             positional_embedding,
             blocks,
             ln,
+            mask,
         })
     }
 
-    fn forward(&self, _tokens: &Tensor, _enc: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+        let x_dims = x.dims();
+        let last = x_dims[x_dims.len() - 1];
+        let token_embedding = self.token_embedding.forward(x)?;
+        let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
+        let mut x = (token_embedding + positional_embedding)?;
+        for block in self.blocks.iter() {
+            x = block.forward(&x, Some(xa))?
+        }
+        let x = self.ln.forward(&x)?;
+        let logits = x.matmul(&self.token_embedding.embeddings.t()?)?;
+        Ok(logits)
     }
 }
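Putting the new decoder forward pass together: embed the token ids, add the first `last` rows of the positional table (`narrow(0, 0, last)`), run the blocks with `Some(xa)` so cross-attention sees the audio features, and finally project back onto the vocabulary by multiplying with the transposed token-embedding matrix (weight tying). A plain-Rust sketch of that data flow with toy matrices, leaving out the blocks and the layer norm (all names and numbers here are illustrative, not the candle API):

// Toy dense matmul: a is (m, k), b is (k, n), result is (m, n).
fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0; n]; m];
    for i in 0..m {
        for j in 0..n {
            for l in 0..k {
                out[i][j] += a[i][l] * b[l][j];
            }
        }
    }
    out
}

fn main() {
    // Toy vocab of 3 tokens, n_state = 2, n_ctx = 4.
    let token_embedding = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let positional_embedding = vec![vec![0.1, 0.0], vec![0.0, 0.1], vec![0.1, 0.1], vec![0.2, 0.2]];
    let tokens = [2usize, 0];
    let seq_len = tokens.len();

    // token_embedding.forward(x) + positional_embedding.narrow(0, 0, last)
    let x: Vec<Vec<f32>> = tokens
        .iter()
        .enumerate()
        .map(|(pos, &t)| {
            token_embedding[t]
                .iter()
                .zip(&positional_embedding[pos])
                .map(|(a, b)| a + b)
                .collect()
        })
        .collect();

    // (the decoder blocks, called with Some(encoder_output), would transform x here)

    // logits = x.matmul(token_embedding.t()): one score per vocab entry and position.
    let emb_t: Vec<Vec<f32>> = (0..2)
        .map(|j| token_embedding.iter().map(|row| row[j]).collect())
        .collect();
    let logits = matmul(&x, &emb_t);
    assert_eq!(logits.len(), seq_len);
    assert_eq!(logits[0].len(), token_embedding.len());
    println!("{logits:?}");
}

With real shapes this yields logits of shape `(last, n_vocab)`: one score per vocabulary entry for each input position.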