mirror of https://github.com/huggingface/candle.git
Add the residual attention block.
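Besides introducing `ResidualAttentionBlock`, the diff below simplifies `MultiHeadAttention::forward`/`qkv_attention` to return a single `Tensor` instead of a `(Tensor, Tensor)` pair, switches the attention loader to `(n_state, n_head)` parameters, and hardcodes the `1e-5` epsilon inside `LayerNorm::load`. As a rough orientation for how the new block is meant to be used, here is a hedged sketch of stacking it into an encoder; the `build_blocks`/`encode` helpers, the `n_layer` count, and the `"encoder.blocks"` prefix are illustrative assumptions, not part of this commit, and the `Tensor`/`VarBuilder`/`Result` types come from the surrounding file:

// Sketch only: load a stack of blocks without cross-attention and run them in sequence.
fn build_blocks(
    n_state: usize,
    n_head: usize,
    n_layer: usize,
    vb: &VarBuilder,
) -> Result<Vec<ResidualAttentionBlock>> {
    (0..n_layer)
        .map(|i| {
            // `false`: no cross-attention for encoder-style blocks.
            ResidualAttentionBlock::load(n_state, n_head, false, &format!("encoder.blocks.{i}"), vb)
        })
        .collect()
}

fn encode(blocks: &[ResidualAttentionBlock], mut x: Tensor) -> Result<Tensor> {
    for block in blocks {
        x = block.forward(&x)?;
    }
    Ok(x)
}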
@@ -176,14 +176,14 @@ struct LayerNorm {
 }
 
 impl LayerNorm {
     fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
         Self { weight, bias, eps }
     }
 
-    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
+    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
         let weight = vb.get(size, &format!("{p}.weight"))?;
         let bias = vb.get(size, &format!("{p}.bias"))?;
-        Ok(Self::new(weight, bias, eps))
+        Ok(Self {
+            weight,
+            bias,
+            eps: 1e-5,
+        })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -209,11 +209,11 @@ struct MultiHeadAttention {
 }
 
 impl MultiHeadAttention {
-    fn load(n_head: usize, n: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let query = Linear::load(n, n, &format!("{p}.query"), vb)?;
-        let value = Linear::load_no_bias(n, n, &format!("{p}.value"), vb)?;
-        let key = Linear::load(n, n, &format!("{p}.key"), vb)?;
-        let out = Linear::load(n, n, &format!("{p}.out"), vb)?;
+    fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
+        let value = Linear::load_no_bias(n_state, n_state, &format!("{p}.value"), vb)?;
+        let key = Linear::load(n_state, n_state, &format!("{p}.key"), vb)?;
+        let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
         Ok(Self {
             query,
             key,
@@ -223,16 +223,16 @@ impl MultiHeadAttention {
         })
     }
 
-    fn forward(&self, x: &Tensor) -> Result<(Tensor, Tensor)> {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let q = self.query.forward(x)?;
         let k = self.key.forward(x)?;
         let v = self.value.forward(x)?;
-        let (wv, qk) = self.qkv_attention(&q, &k, &v)?;
+        let wv = self.qkv_attention(&q, &k, &v)?;
         let out = self.out.forward(&wv)?;
-        Ok((out, qk))
+        Ok(out)
     }
 
-    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
         let (n_batch, n_ctx, n_state) = q.shape().r3()?;
         let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
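One detail worth spelling out from the context lines above: `qkv_attention` scales both q and k by `(n_state / n_head).powf(-0.25)` rather than dividing the attention scores by the square root of the head dimension. The two are equivalent, since the two factors multiply to `1 / sqrt(d_head)` inside the `q.matmul(&k)` in the next hunk; applying the scale on both sides just keeps the intermediate magnitudes smaller. A standalone arithmetic check (plain `f64`, not candle tensors, and independent of this commit):

fn main() {
    let d_head = 64.0_f64; // e.g. n_state / n_head
    let scale = d_head.powf(-0.25);
    let (q, k) = (3.0_f64, 5.0_f64); // stand-ins for individual q/k entries
    let split = (q * scale) * (k * scale); // scale applied to both q and k
    let direct = (q * k) / d_head.sqrt(); // conventional 1/sqrt(d_head) scaling
    assert!((split - direct).abs() < 1e-12);
}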
@@ -242,8 +242,66 @@ impl MultiHeadAttention {
         let qk = q.matmul(&k)?;
         let w = qk.softmax(qk.rank() - 1)?;
         let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
-        let qk = qk.detach()?;
-        Ok((wv, qk))
+        Ok(wv)
     }
 }
 
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+struct ResidualAttentionBlock {
+    attn: MultiHeadAttention,
+    attn_ln: LayerNorm,
+    cross_attn: Option<MultiHeadAttention>,
+    cross_attn_ln: Option<LayerNorm>,
+    mlp_linear1: Linear,
+    mlp_linear2: Linear,
+    mlp_ln: LayerNorm,
+}
+
+impl ResidualAttentionBlock {
+    fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
+        let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
+        let (cross_attn, cross_attn_ln) = if ca {
+            let cross_attn =
+                MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
+            let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
+            (Some(cross_attn), Some(cross_attn_ln))
+        } else {
+            (None, None)
+        };
+        let n_mlp = n_state * 4;
+        let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
+        let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.mlp.2"), vb)?;
+        let mlp_ln = LayerNorm::load(n_state, &format!("{p}.mlp_ln"), vb)?;
+        Ok(Self {
+            attn,
+            attn_ln,
+            cross_attn,
+            cross_attn_ln,
+            mlp_linear1,
+            mlp_linear2,
+            mlp_ln,
+        })
+    }
+
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let attn = self.attn.forward(&self.attn_ln.forward(x)?)?;
+        let mut x = (x + attn)?;
+        // Cross-Attn
+        if let Some(cross_attn) = &self.cross_attn {
+            x = cross_attn.forward(&x)?
+        }
+        if let Some(cross_attn_ln) = &self.cross_attn_ln {
+            x = cross_attn_ln.forward(&x)?
+        }
+        // Mlp
+        let mlp = self.mlp_linear2.forward(
+            &self
+                .mlp_linear1
+                .forward(&self.mlp_ln.forward(&x)?)?
+                .gelu()?,
+        )?;
+        Ok((x + mlp)?)
+    }
+}
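Since every nested `load` call threads the `p` prefix through a `format!`, the tensor names requested from the `VarBuilder` compose mechanically from the block's prefix. A hedged decoder-side sketch; the `decoder.blocks.0` prefix and the surrounding `n_state`/`n_head`/`vb`/`x` bindings are illustrative, not part of this commit:

// `true` also wires up the cross-attention weights ({p}.cross_attn / {p}.cross_attn_ln).
let block = ResidualAttentionBlock::load(n_state, n_head, true, "decoder.blocks.0", vb)?;
// Reading off the format! calls above, the block looks up names such as:
//   decoder.blocks.0.attn.query.{weight,bias}
//   decoder.blocks.0.attn.value.weight            (load_no_bias in this version)
//   decoder.blocks.0.attn.key.{weight,bias}
//   decoder.blocks.0.attn.out.{weight,bias}
//   decoder.blocks.0.attn_ln.{weight,bias}
//   decoder.blocks.0.cross_attn.* and decoder.blocks.0.cross_attn_ln.*
//   decoder.blocks.0.mlp.0.{weight,bias} and decoder.blocks.0.mlp.2.{weight,bias}
//   decoder.blocks.0.mlp_ln.{weight,bias}
let y = block.forward(&x)?;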