Line up the textdecoder values with the python implementation.

This commit is contained in:
laurent
2023-07-04 12:27:05 +01:00
parent aea090401d
commit 0d99b43792

View File

@ -323,13 +323,18 @@ impl MultiHeadAttention {
Ok(out) Ok(out)
} }
fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
let (n_batch, n_ctx, n_state) = q.shape().r3()?; let (n_batch, n_ctx, n_state) = x.shape().r3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
}
fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let (_, _, n_state) = q.shape().r3()?;
let scale = ((n_state / self.n_head) as f64).powf(-0.25); let scale = ((n_state / self.n_head) as f64).powf(-0.25);
let q = (q.reshape(target_dims)?.transpose(1, 2)? * scale)?; let q = (self.reshape_head(q)? * scale)?;
let k = (k.reshape(target_dims)?.transpose(1, 2)?.transpose(2, 3)? * scale)?; let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
let v = v.reshape(target_dims)?.transpose(1, 2)?; let v = self.reshape_head(v)?.contiguous()?;
let qk = q.matmul(&k)?; let qk = q.matmul(&k)?;
let w = qk.softmax(qk.rank() - 1)?; let w = qk.softmax(qk.rank() - 1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?; let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
@ -341,8 +346,7 @@ impl MultiHeadAttention {
struct ResidualAttentionBlock { struct ResidualAttentionBlock {
attn: MultiHeadAttention, attn: MultiHeadAttention,
attn_ln: LayerNorm, attn_ln: LayerNorm,
cross_attn: Option<MultiHeadAttention>, cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
cross_attn_ln: Option<LayerNorm>,
mlp_linear1: Linear, mlp_linear1: Linear,
mlp_linear2: Linear, mlp_linear2: Linear,
mlp_ln: LayerNorm, mlp_ln: LayerNorm,
@ -352,13 +356,13 @@ impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> { fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?; let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?; let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
let (cross_attn, cross_attn_ln) = if ca { let cross_attn = if ca {
let cross_attn = let cross_attn =
MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?; MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?; let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
(Some(cross_attn), Some(cross_attn_ln)) Some((cross_attn, cross_attn_ln))
} else { } else {
(None, None) None
}; };
let n_mlp = n_state * 4; let n_mlp = n_state * 4;
let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?; let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
@ -368,7 +372,6 @@ impl ResidualAttentionBlock {
attn, attn,
attn_ln, attn_ln,
cross_attn, cross_attn,
cross_attn_ln,
mlp_linear1, mlp_linear1,
mlp_linear2, mlp_linear2,
mlp_ln, mlp_ln,
@ -379,11 +382,8 @@ impl ResidualAttentionBlock {
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?; let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
let mut x = (x + attn)?; let mut x = (x + attn)?;
// Cross-Attn // Cross-Attn
if let Some(cross_attn_ln) = &self.cross_attn_ln { if let Some((attn, ln)) = &self.cross_attn {
x = cross_attn_ln.forward(&x)? x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
}
if let Some(cross_attn) = &self.cross_attn {
x = cross_attn.forward(&x, xa)?
} }
// Mlp // Mlp
let mlp = self.mlp_linear2.forward( let mlp = self.mlp_linear2.forward(
@ -482,7 +482,7 @@ impl TextDecoder {
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?; vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
let blocks = (0..cfg.n_text_layer) let blocks = (0..cfg.n_text_layer)
.map(|i| { .map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb) ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?; let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
@ -502,7 +502,7 @@ impl TextDecoder {
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = token_embedding.broadcast_add(&positional_embedding)?; let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() { for block in self.blocks.iter() {
x = block.forward(&x, Some(xa))? x = block.forward(&x, Some(xa))?;
} }
let x = self.ln.forward(&x)?; let x = self.ln.forward(&x)?;
let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?; let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;