Add the residual attention block.

laurent
2023-07-04 07:43:36 +01:00
parent b1d42231fb
commit 0ca2af6940


@@ -176,14 +176,14 @@ struct LayerNorm {
}
impl LayerNorm {
fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
Self { weight, bias, eps }
}
-fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
+fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get(size, &format!("{p}.weight"))?;
let bias = vb.get(size, &format!("{p}.bias"))?;
-Ok(Self::new(weight, bias, eps))
+Ok(Self {
+weight,
+bias,
+eps: 1e-5,
+})
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -209,11 +209,11 @@ struct MultiHeadAttention {
}
impl MultiHeadAttention {
-fn load(n_head: usize, n: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-let query = Linear::load(n, n, &format!("{p}.query"), vb)?;
-let value = Linear::load_no_bias(n, n, &format!("{p}.value"), vb)?;
-let key = Linear::load(n, n, &format!("{p}.key"), vb)?;
-let out = Linear::load(n, n, &format!("{p}.out"), vb)?;
+fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
+let value = Linear::load_no_bias(n_state, n_state, &format!("{p}.value"), vb)?;
+let key = Linear::load(n_state, n_state, &format!("{p}.key"), vb)?;
+let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
Ok(Self {
query,
key,
@@ -223,16 +223,16 @@ impl MultiHeadAttention {
})
}
-fn forward(&self, x: &Tensor) -> Result<(Tensor, Tensor)> {
+fn forward(&self, x: &Tensor) -> Result<Tensor> {
let q = self.query.forward(x)?;
let k = self.key.forward(x)?;
let v = self.value.forward(x)?;
-let (wv, qk) = self.qkv_attention(&q, &k, &v)?;
+let wv = self.qkv_attention(&q, &k, &v)?;
let out = self.out.forward(&wv)?;
-Ok((out, qk))
+Ok(out)
}
-fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = q.shape().r3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
@@ -242,8 +242,66 @@ impl MultiHeadAttention {
let qk = q.matmul(&k)?;
let w = qk.softmax(qk.rank() - 1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
-let qk = qk.detach()?;
-Ok((wv, qk))
+Ok(wv)
}
}
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+struct ResidualAttentionBlock {
+attn: MultiHeadAttention,
+attn_ln: LayerNorm,
+cross_attn: Option<MultiHeadAttention>,
+cross_attn_ln: Option<LayerNorm>,
+mlp_linear1: Linear,
+mlp_linear2: Linear,
+mlp_ln: LayerNorm,
+}
+impl ResidualAttentionBlock {
+fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
+let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
+let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
+let (cross_attn, cross_attn_ln) = if ca {
+let cross_attn =
+MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
+let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
+(Some(cross_attn), Some(cross_attn_ln))
+} else {
+(None, None)
+};
+let n_mlp = n_state * 4;
+let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
+let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.mlp.2"), vb)?;
+let mlp_ln = LayerNorm::load(n_state, &format!("{p}.mlp_ln"), vb)?;
+Ok(Self {
+attn,
+attn_ln,
+cross_attn,
+cross_attn_ln,
+mlp_linear1,
+mlp_linear2,
+mlp_ln,
+})
+}
+fn forward(&self, x: &Tensor) -> Result<Tensor> {
+let attn = self.attn.forward(&self.attn_ln.forward(x)?)?;
+let mut x = (x + attn)?;
+// Cross-Attn
+if let Some(cross_attn) = &self.cross_attn {
+x = cross_attn.forward(&x)?
+}
+if let Some(cross_attn_ln) = &self.cross_attn_ln {
+x = cross_attn_ln.forward(&x)?
+}
+// Mlp
+let mlp = self.mlp_linear2.forward(
+&self
+.mlp_linear1
+.forward(&self.mlp_ln.forward(&x)?)?
+.gelu()?,
+)?;
+Ok((x + mlp)?)
+}
+}
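
For context, a minimal usage sketch, not part of the commit: it loads a single block through the load helper added above and runs one forward pass. The run_block wrapper, the 384/6 dimensions (the Whisper tiny configuration) and the "encoder.blocks.0" prefix are illustrative assumptions; Tensor, VarBuilder and Result are the same types already used in the diff.

// Usage sketch (illustrative, not from this commit): load one encoder block and
// apply it to activations of shape (batch, ctx, n_state).
fn run_block(vb: &VarBuilder, xs: &Tensor) -> Result<Tensor> {
    // Hypothetical dimensions matching the Whisper tiny configuration.
    let n_state = 384;
    let n_head = 6;
    // ca = false: encoder blocks carry no cross-attention; decoder blocks would pass true.
    let block = ResidualAttentionBlock::load(n_state, n_head, false, "encoder.blocks.0", vb)?;
    // The output keeps the (batch, ctx, n_state) shape of the input.
    block.forward(xs)
}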