Line up the TextDecoder values with the Python implementation.

Author: laurent
Date:   2023-07-04 12:27:05 +01:00
Parent: aea090401d
Commit: 0d99b43792


@@ -323,13 +323,18 @@ impl MultiHeadAttention {
         Ok(out)
     }
 
-    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-        let (n_batch, n_ctx, n_state) = q.shape().r3()?;
+    fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
+        let (n_batch, n_ctx, n_state) = x.shape().r3()?;
         let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
+        Ok(x.reshape(target_dims)?.transpose(1, 2)?)
+    }
+
+    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+        let (_, _, n_state) = q.shape().r3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
-        let q = (q.reshape(target_dims)?.transpose(1, 2)? * scale)?;
-        let k = (k.reshape(target_dims)?.transpose(1, 2)?.transpose(2, 3)? * scale)?;
-        let v = v.reshape(target_dims)?.transpose(1, 2)?;
+        let q = (self.reshape_head(q)? * scale)?;
+        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
+        let v = self.reshape_head(v)?.contiguous()?;
         let qk = q.matmul(&k)?;
         let w = qk.softmax(qk.rank() - 1)?;
         let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
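
For context on the hunk above: reshape_head splits the last dimension into n_head heads and moves the head axis forward, and scaling both q and k by (n_state / n_head)^-0.25 scales the q.matmul(&k)? scores by the usual 1/sqrt(head_dim). A minimal standalone sketch of that bookkeeping, in plain Rust with illustrative values and no candle types:

fn main() {
    // Shape bookkeeping performed by reshape_head:
    // (n_batch, n_ctx, n_state) -> reshape -> (n_batch, n_ctx, n_head, head_dim)
    //                           -> transpose(1, 2) -> (n_batch, n_head, n_ctx, head_dim)
    let (n_batch, n_ctx, n_state, n_head) = (1usize, 3usize, 8usize, 2usize);
    let head_dim = n_state / n_head;
    let target_dims = [n_batch, n_ctx, n_head, head_dim];
    let transposed = [n_batch, n_head, n_ctx, head_dim];
    println!("reshape to {target_dims:?}, transpose(1, 2) to {transposed:?}");

    // Scaling both q and k by head_dim^-0.25 scales the q.k^T scores by
    // head_dim^-0.5, i.e. the usual 1/sqrt(head_dim) attention scaling.
    let scale = (head_dim as f64).powf(-0.25);
    assert!((scale * scale - 1.0 / (head_dim as f64).sqrt()).abs() < 1e-12);
}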
@@ -341,8 +346,7 @@ impl MultiHeadAttention {
 struct ResidualAttentionBlock {
     attn: MultiHeadAttention,
     attn_ln: LayerNorm,
-    cross_attn: Option<MultiHeadAttention>,
-    cross_attn_ln: Option<LayerNorm>,
+    cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
     mlp_linear1: Linear,
     mlp_linear2: Linear,
     mlp_ln: LayerNorm,
@@ -352,13 +356,13 @@ impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
         let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
         let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
-        let (cross_attn, cross_attn_ln) = if ca {
+        let cross_attn = if ca {
             let cross_attn =
                 MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
             let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
-            (Some(cross_attn), Some(cross_attn_ln))
+            Some((cross_attn, cross_attn_ln))
         } else {
-            (None, None)
+            None
         };
         let n_mlp = n_state * 4;
         let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
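
The load path above now yields a single Option<(MultiHeadAttention, LayerNorm)>, so the cross-attention weights and their layer norm are either both present or both absent. A tiny sketch of the same construction pattern, with placeholder strings standing in for the loaded modules:

fn main() {
    let ca = true;
    // Either both parts are built or neither is; the tuple keeps them in sync.
    let cross_attn: Option<(&str, &str)> = if ca {
        Some(("cross_attn weights", "cross_attn_ln weights"))
    } else {
        None
    };
    match &cross_attn {
        Some((attn, ln)) => println!("loaded {attn} and {ln}"),
        None => println!("no cross-attention in this block"),
    }
}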
@@ -368,7 +372,6 @@ impl ResidualAttentionBlock {
             attn,
             attn_ln,
             cross_attn,
-            cross_attn_ln,
             mlp_linear1,
             mlp_linear2,
             mlp_ln,
@@ -379,11 +382,8 @@ impl ResidualAttentionBlock {
         let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
         let mut x = (x + attn)?;
         // Cross-Attn
-        if let Some(cross_attn_ln) = &self.cross_attn_ln {
-            x = cross_attn_ln.forward(&x)?
-        }
-        if let Some(cross_attn) = &self.cross_attn {
-            x = cross_attn.forward(&x, xa)?
+        if let Some((attn, ln)) = &self.cross_attn {
+            x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
         }
         // Mlp
         let mlp = self.mlp_linear2.forward(
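
In the forward pass above, cross-attention is now applied as a pre-norm residual, x = x + attn.forward(&ln.forward(&x)?, xa)?, rather than overwriting x, which matches the residual form used by the Python implementation. A toy standalone sketch of destructuring the bundled Option and adding the residual, with plain functions standing in for the candle modules:

fn double(x: f64) -> f64 { 2.0 * x }
fn shift(x: f64) -> f64 { x - 1.0 }

struct Block {
    // Stand-ins for (MultiHeadAttention, LayerNorm); both travel together.
    cross_attn: Option<(fn(f64) -> f64, fn(f64) -> f64)>,
}

impl Block {
    fn forward(&self, x: f64) -> f64 {
        let mut x = x;
        if let Some((attn, ln)) = &self.cross_attn {
            // Pre-norm residual: normalize, attend, then add the result back onto x.
            x = x + attn(ln(x));
        }
        x
    }
}

fn main() {
    let block = Block { cross_attn: Some((double, shift)) };
    assert_eq!(block.forward(3.0), 3.0 + double(shift(3.0)));
    println!("ok");
}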
@@ -482,7 +482,7 @@ impl TextDecoder {
             vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
         let blocks = (0..cfg.n_text_layer)
             .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
+                ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
         let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
@@ -502,7 +502,7 @@ impl TextDecoder {
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
         let mut x = token_embedding.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x, Some(xa))?
+            x = block.forward(&x, Some(xa))?;
         }
         let x = self.ln.forward(&x)?;
         let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;