mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Line up the textdecoder values with the python implementation.
This commit is contained in:
@ -323,13 +323,18 @@ impl MultiHeadAttention {
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||
let (n_batch, n_ctx, n_state) = q.shape().r3()?;
|
||||
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (n_batch, n_ctx, n_state) = x.shape().r3()?;
|
||||
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
||||
}
|
||||
|
||||
fn qkv_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||
let (_, _, n_state) = q.shape().r3()?;
|
||||
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||
let q = (q.reshape(target_dims)?.transpose(1, 2)? * scale)?;
|
||||
let k = (k.reshape(target_dims)?.transpose(1, 2)?.transpose(2, 3)? * scale)?;
|
||||
let v = v.reshape(target_dims)?.transpose(1, 2)?;
|
||||
let q = (self.reshape_head(q)? * scale)?;
|
||||
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||
let v = self.reshape_head(v)?.contiguous()?;
|
||||
let qk = q.matmul(&k)?;
|
||||
let w = qk.softmax(qk.rank() - 1)?;
|
||||
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
|
||||
@ -341,8 +346,7 @@ impl MultiHeadAttention {
|
||||
struct ResidualAttentionBlock {
|
||||
attn: MultiHeadAttention,
|
||||
attn_ln: LayerNorm,
|
||||
cross_attn: Option<MultiHeadAttention>,
|
||||
cross_attn_ln: Option<LayerNorm>,
|
||||
cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
|
||||
mlp_linear1: Linear,
|
||||
mlp_linear2: Linear,
|
||||
mlp_ln: LayerNorm,
|
||||
@ -352,13 +356,13 @@ impl ResidualAttentionBlock {
|
||||
fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
|
||||
let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
|
||||
let (cross_attn, cross_attn_ln) = if ca {
|
||||
let cross_attn = if ca {
|
||||
let cross_attn =
|
||||
MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
|
||||
let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
|
||||
(Some(cross_attn), Some(cross_attn_ln))
|
||||
Some((cross_attn, cross_attn_ln))
|
||||
} else {
|
||||
(None, None)
|
||||
None
|
||||
};
|
||||
let n_mlp = n_state * 4;
|
||||
let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
|
||||
@ -368,7 +372,6 @@ impl ResidualAttentionBlock {
|
||||
attn,
|
||||
attn_ln,
|
||||
cross_attn,
|
||||
cross_attn_ln,
|
||||
mlp_linear1,
|
||||
mlp_linear2,
|
||||
mlp_ln,
|
||||
@ -379,11 +382,8 @@ impl ResidualAttentionBlock {
|
||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
|
||||
let mut x = (x + attn)?;
|
||||
// Cross-Attn
|
||||
if let Some(cross_attn_ln) = &self.cross_attn_ln {
|
||||
x = cross_attn_ln.forward(&x)?
|
||||
}
|
||||
if let Some(cross_attn) = &self.cross_attn {
|
||||
x = cross_attn.forward(&x, xa)?
|
||||
if let Some((attn, ln)) = &self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
|
||||
}
|
||||
// Mlp
|
||||
let mlp = self.mlp_linear2.forward(
|
||||
@ -482,7 +482,7 @@ impl TextDecoder {
|
||||
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
|
||||
let blocks = (0..cfg.n_text_layer)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
|
||||
ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
|
||||
@ -502,7 +502,7 @@ impl TextDecoder {
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, Some(xa))?
|
||||
x = block.forward(&x, Some(xa))?;
|
||||
}
|
||||
let x = self.ln.forward(&x)?;
|
||||
let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
|
||||
|
Reference in New Issue
Block a user