Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Line up the TextDecoder values with the Python implementation.
@@ -323,13 +323,18 @@ impl MultiHeadAttention {
         Ok(out)
     }
 
-    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
-        let (n_batch, n_ctx, n_state) = q.shape().r3()?;
+    fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
+        let (n_batch, n_ctx, n_state) = x.shape().r3()?;
         let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
+        Ok(x.reshape(target_dims)?.transpose(1, 2)?)
+    }
+
+    fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+        let (_, _, n_state) = q.shape().r3()?;
         let scale = ((n_state / self.n_head) as f64).powf(-0.25);
-        let q = (q.reshape(target_dims)?.transpose(1, 2)? * scale)?;
-        let k = (k.reshape(target_dims)?.transpose(1, 2)?.transpose(2, 3)? * scale)?;
-        let v = v.reshape(target_dims)?.transpose(1, 2)?;
+        let q = (self.reshape_head(q)? * scale)?;
+        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
+        let v = self.reshape_head(v)?.contiguous()?;
         let qk = q.matmul(&k)?;
         let w = qk.softmax(qk.rank() - 1)?;
         let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
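The refactor above pulls the head-splitting reshape into a shared `reshape_head` helper and keeps Whisper's convention of scaling both `q` and `k` by `(n_state / n_head)^-0.25`, so their product ends up scaled by the usual `1/sqrt(d_head)`. A minimal sketch of that shape and scale arithmetic in plain Rust (no candle types; the dimensions are made up for illustration):

    fn main() {
        // Illustrative dimensions, not taken from the diff.
        let (n_batch, n_ctx, n_state, n_head) = (1usize, 3usize, 8usize, 2usize);

        // reshape_head maps (n_batch, n_ctx, n_state) to
        // (n_batch, n_head, n_ctx, n_state / n_head) via reshape + transpose(1, 2).
        let target_dims = [n_batch, n_ctx, n_head, n_state / n_head];
        let after_transpose = [n_batch, n_head, n_ctx, n_state / n_head];
        println!("reshape to {target_dims:?}, transpose(1, 2) -> {after_transpose:?}");

        // Scaling q and k each by d_head^-0.25 scales q·kᵀ by d_head^-0.5,
        // i.e. the standard 1/sqrt(d_head) attention scaling.
        let d_head = (n_state / n_head) as f64;
        let scale = d_head.powf(-0.25);
        assert!((scale * scale - d_head.powf(-0.5)).abs() < 1e-12);
    }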
@@ -341,8 +346,7 @@ impl MultiHeadAttention {
 struct ResidualAttentionBlock {
     attn: MultiHeadAttention,
     attn_ln: LayerNorm,
-    cross_attn: Option<MultiHeadAttention>,
-    cross_attn_ln: Option<LayerNorm>,
+    cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
     mlp_linear1: Linear,
     mlp_linear2: Linear,
     mlp_ln: LayerNorm,
@@ -352,13 +356,13 @@ impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
         let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
         let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
-        let (cross_attn, cross_attn_ln) = if ca {
+        let cross_attn = if ca {
             let cross_attn =
                 MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
             let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
-            (Some(cross_attn), Some(cross_attn_ln))
+            Some((cross_attn, cross_attn_ln))
         } else {
-            (None, None)
+            None
         };
         let n_mlp = n_state * 4;
         let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
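Collapsing the two parallel fields into a single `Option<(MultiHeadAttention, LayerNorm)>` means the cross-attention weights and their layer norm can no longer be present without each other. A simplified sketch of the same pattern with stand-in types (the names below are illustrative, not candle's API):

    // Stand-in types; in the diff these are candle's MultiHeadAttention and LayerNorm.
    struct Attn;
    struct Norm;

    struct Block {
        // One Option over the pair instead of two Options that must stay in sync.
        cross_attn: Option<(Attn, Norm)>,
    }

    fn load(use_cross_attn: bool) -> Block {
        let cross_attn = if use_cross_attn {
            Some((Attn, Norm))
        } else {
            None
        };
        Block { cross_attn }
    }

    fn main() {
        let block = load(true);
        // Both pieces are destructured together, so neither can be missing on its own.
        if let Some((_attn, _norm)) = &block.cross_attn {
            println!("cross-attention enabled");
        }
    }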
@@ -368,7 +372,6 @@ impl ResidualAttentionBlock {
             attn,
             attn_ln,
             cross_attn,
-            cross_attn_ln,
             mlp_linear1,
             mlp_linear2,
             mlp_ln,
@@ -379,11 +382,8 @@ impl ResidualAttentionBlock {
         let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
         let mut x = (x + attn)?;
         // Cross-Attn
-        if let Some(cross_attn_ln) = &self.cross_attn_ln {
-            x = cross_attn_ln.forward(&x)?
-        }
-        if let Some(cross_attn) = &self.cross_attn {
-            x = cross_attn.forward(&x, xa)?
+        if let Some((attn, ln)) = &self.cross_attn {
+            x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
         }
         // Mlp
         let mlp = self.mlp_linear2.forward(
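Beyond merging the two `if let` blocks, the rewrite changes what gets computed: the old code replaced `x` with the layer-norm output and then with the cross-attention output, dropping the residual path, while the new code computes `x = x + cross_attn(ln(x), xa)` as the Python block does. A scalar toy comparison of the two control flows (stand-in functions, purely illustrative):

    // Toy stand-ins: "layer norm" halves the value, "cross attention" mixes in the audio features.
    fn ln(x: f64) -> f64 {
        x * 0.5
    }

    fn cross_attn(x: f64, xa: f64) -> f64 {
        x + xa
    }

    fn main() {
        let (x0, xa) = (2.0, 10.0);

        // Old flow: x is overwritten twice, so the residual contribution of x0 is lost.
        let mut old = x0;
        old = ln(old);
        old = cross_attn(old, xa);

        // New flow: pre-norm residual, x0 still feeds the output directly.
        let new = x0 + cross_attn(ln(x0), xa);

        println!("old = {old}, new = {new}"); // old = 11, new = 13
    }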
@@ -482,7 +482,7 @@ impl TextDecoder {
             vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
         let blocks = (0..cfg.n_text_layer)
             .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
+                ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
         let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
@@ -502,7 +502,7 @@ impl TextDecoder {
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
         let mut x = token_embedding.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
-            x = block.forward(&x, Some(xa))?
+            x = block.forward(&x, Some(xa))?;
         }
         let x = self.ln.forward(&x)?;
         let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;