Support sd3.5 medium and MMDiT-X (#2587)
* extract attn out of joint_attn
* further adjust attn and joint_attn
* add mmdit-x support
* support sd3.5-medium in the example
* update README.md
@@ -36,7 +36,6 @@ impl Module for LayerNormNoAffine {
 
 impl DiTBlock {
     pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
-        // {'hidden_size': 1536, 'num_heads': 24}
         let norm1 = LayerNormNoAffine::new(1e-6);
         let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
         let norm2 = LayerNormNoAffine::new(1e-6);
@@ -103,6 +102,117 @@ impl DiTBlock {
     }
 }
 
+pub struct SelfAttnModulateIntermediates {
+    gate_msa: Tensor,
+    shift_mlp: Tensor,
+    scale_mlp: Tensor,
+    gate_mlp: Tensor,
+    gate_msa2: Tensor,
+}
+
+pub struct SelfAttnDiTBlock {
+    norm1: LayerNormNoAffine,
+    attn: AttnProjections,
+    attn2: AttnProjections,
+    norm2: LayerNormNoAffine,
+    mlp: Mlp,
+    ada_ln_modulation: nn::Sequential,
+}
+
+impl SelfAttnDiTBlock {
+    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+        let norm1 = LayerNormNoAffine::new(1e-6);
+        let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
+        let attn2 = AttnProjections::new(hidden_size, num_heads, vb.pp("attn2"))?;
+        let norm2 = LayerNormNoAffine::new(1e-6);
+        let mlp_ratio = 4;
+        let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
+        let n_mods = 9;
+        let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
+            hidden_size,
+            n_mods * hidden_size,
+            vb.pp("adaLN_modulation.1"),
+        )?);
+
+        Ok(Self {
+            norm1,
+            attn,
+            attn2,
+            norm2,
+            mlp,
+            ada_ln_modulation,
+        })
+    }
+
+    pub fn pre_attention(
+        &self,
+        x: &Tensor,
+        c: &Tensor,
+    ) -> Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> {
+        let modulation = self.ada_ln_modulation.forward(c)?;
+        let chunks = modulation.chunk(9, D::Minus1)?;
+        let (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            shift_msa2,
+            scale_msa2,
+            gate_msa2,
+        ) = (
+            chunks[0].clone(),
+            chunks[1].clone(),
+            chunks[2].clone(),
+            chunks[3].clone(),
+            chunks[4].clone(),
+            chunks[5].clone(),
+            chunks[6].clone(),
+            chunks[7].clone(),
+            chunks[8].clone(),
+        );
+
+        let norm_x = self.norm1.forward(x)?;
+        let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
+        let qkv = self.attn.pre_attention(&modulated_x)?;
+
+        let modulated_x2 = modulate(&norm_x, &shift_msa2, &scale_msa2)?;
+        let qkv2 = self.attn2.pre_attention(&modulated_x2)?;
+
+        Ok((
+            qkv,
+            qkv2,
+            SelfAttnModulateIntermediates {
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                gate_msa2,
+            },
+        ))
+    }
+
+    pub fn post_attention(
+        &self,
+        attn: &Tensor,
+        attn2: &Tensor,
+        x: &Tensor,
+        mod_interm: &SelfAttnModulateIntermediates,
+    ) -> Result<Tensor> {
+        let attn_out = self.attn.post_attention(attn)?;
+        let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
+        let attn_out2 = self.attn2.post_attention(attn2)?;
+        let x = x.add(&attn_out2.broadcast_mul(&mod_interm.gate_msa2.unsqueeze(1)?)?)?;
+
+        let norm_x = self.norm2.forward(&x)?;
+        let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
+        let mlp_out = self.mlp.forward(&modulated_x)?;
+        let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
+        Ok(x)
+    }
+}
+
 pub struct QkvOnlyDiTBlock {
     norm1: LayerNormNoAffine,
     attn: QkvOnlyAttnProjections,
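Taken together, the new block splits a transformer step into two halves around the attention computation: pre_attention produces two Qkv sets (one per attention branch) plus the gates needed later, and post_attention folds both gated attention results and the MLP back into x. A minimal sketch of how a caller wires this up (hypothetical driver code, not part of the diff; in MMDiTXJointBlock further below, the first Qkv actually runs through joint_attn together with the context stream):

    // Hypothetical driver, assuming `block: &SelfAttnDiTBlock`, latent `x`,
    // conditioning `c`, `num_heads`, and `use_flash_attn` are in scope.
    let (qkv, qkv2, interm) = block.pre_attention(&x, &c)?;
    let attn_out = attn(&qkv, num_heads, use_flash_attn)?; // first attention path
    let attn_out2 = attn(&qkv2, num_heads, use_flash_attn)?; // second, MMDiT-X-only path
    let x = block.post_attention(&attn_out, &attn_out2, &x, &interm)?;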
@@ -190,14 +300,18 @@ fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
     shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
 }
 
-pub struct JointBlock {
+pub trait JointBlock {
+    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)>;
+}
+
+pub struct MMDiTJointBlock {
     x_block: DiTBlock,
     context_block: DiTBlock,
     num_heads: usize,
     use_flash_attn: bool,
 }
 
-impl JointBlock {
+impl MMDiTJointBlock {
     pub fn new(
         hidden_size: usize,
         num_heads: usize,
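Only the closing line of modulate is visible in the hunk above. The AdaLN modulation it implements is out = shift + x * (1 + scale); a sketch of a body consistent with that last line (the unsqueeze handling and exact formulation of 1 + scale are assumptions, not taken from the diff):

    fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
        // Give shift/scale a sequence axis so they broadcast against x:
        // (batch, hidden) -> (batch, 1, hidden).
        let shift = shift.unsqueeze(1)?;
        let scale = scale.unsqueeze(1)?;
        let scale_plus_one = (&scale + 1.0)?;
        shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
    }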
@@ -214,8 +328,10 @@ impl JointBlock {
             use_flash_attn,
         })
     }
+}
 
-    pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
+impl JointBlock for MMDiTJointBlock {
+    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
         let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
         let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
         let (context_attn, x_attn) =
@@ -228,6 +344,49 @@ impl JointBlock {
     }
 }
 
+pub struct MMDiTXJointBlock {
+    x_block: SelfAttnDiTBlock,
+    context_block: DiTBlock,
+    num_heads: usize,
+    use_flash_attn: bool,
+}
+
+impl MMDiTXJointBlock {
+    pub fn new(
+        hidden_size: usize,
+        num_heads: usize,
+        use_flash_attn: bool,
+        vb: nn::VarBuilder,
+    ) -> Result<Self> {
+        let x_block = SelfAttnDiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
+        let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
+
+        Ok(Self {
+            x_block,
+            context_block,
+            num_heads,
+            use_flash_attn,
+        })
+    }
+}
+
+impl JointBlock for MMDiTXJointBlock {
+    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
+        let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
+        let (x_qkv, x_qkv2, x_interm) = self.x_block.pre_attention(x, c)?;
+        let (context_attn, x_attn) =
+            joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
+        let x_attn2 = attn(&x_qkv2, self.num_heads, self.use_flash_attn)?;
+        let context_out =
+            self.context_block
+                .post_attention(&context_attn, context, &context_interm)?;
+        let x_out = self
+            .x_block
+            .post_attention(&x_attn, &x_attn2, x, &x_interm)?;
+        Ok((context_out, x_out))
+    }
+}
+
 pub struct ContextQkvOnlyJointBlock {
     x_block: DiTBlock,
     context_block: QkvOnlyDiTBlock,
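The point of turning JointBlock into a trait is that the MMDiT block (SD3 / SD3.5-large) and the MMDiT-X block (SD3.5-medium) can now sit behind one interface, so a model can hold a heterogeneous stack of blocks and iterate over them uniformly. A hypothetical sketch of such usage (constructor arguments and varbuilder paths are illustrative; MMDiTJointBlock::new's full signature is cut off by the hunk boundary above and assumed to match MMDiTXJointBlock::new):

    // Hypothetical: a mixed stack behind the trait, via dynamic dispatch.
    let blocks: Vec<Box<dyn JointBlock>> = vec![
        Box::new(MMDiTJointBlock::new(hidden_size, num_heads, use_flash_attn, vb.pp("joint_blocks.0"))?),
        Box::new(MMDiTXJointBlock::new(hidden_size, num_heads, use_flash_attn, vb.pp("joint_blocks.1"))?),
    ];
    let (mut context, mut x) = (context_in, x_in);
    for block in blocks.iter() {
        (context, x) = block.forward(&context, &x, &c)?;
    }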
@@ -309,26 +468,30 @@ fn joint_attn(
         v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
     };
 
-    let (batch_size, seqlen, _) = qkv.q.dims3()?;
-    let qkv = Qkv {
-        q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
-        k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
-        v: qkv.v,
-    };
-
-    let headdim = qkv.q.dim(D::Minus1)?;
-    let softmax_scale = 1.0 / (headdim as f64).sqrt();
-
-    let attn = if use_flash_attn {
-        flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
-    } else {
-        flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
-    };
-
-    let attn = attn.reshape((batch_size, seqlen, ()))?;
+    let seqlen = qkv.q.dim(1)?;
+    let attn = attn(&qkv, num_heads, use_flash_attn)?;
     let context_qkv_seqlen = context_qkv.q.dim(1)?;
     let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
     let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
 
     Ok((context_attn, x_attn))
 }
+
+fn attn(qkv: &Qkv, num_heads: usize, use_flash_attn: bool) -> Result<Tensor> {
+    let batch_size = qkv.q.dim(0)?;
+    let seqlen = qkv.q.dim(1)?;
+    let qkv = Qkv {
+        q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
+        k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
+        v: qkv.v.clone(),
+    };
+
+    let headdim = qkv.q.dim(D::Minus1)?;
+    let softmax_scale = 1.0 / (headdim as f64).sqrt();
+    let attn = if use_flash_attn {
+        flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
+    } else {
+        flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
+    };
+    attn.reshape((batch_size, seqlen, ()))
+}
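The non-flash branch calls flash_compatible_attention, which is defined elsewhere in this file and not shown in this diff. For readers without that context, here is a self-contained sketch of the computation such a fallback performs: scaled dot-product attention with the same 1/sqrt(headdim) scale. This is a hypothetical stand-in, not the upstream code; it assumes all three inputs are shaped (batch, seqlen, num_heads, headdim), whereas `attn` above passes v un-reshaped, so the real helper's layout handling may differ.

    use candle_core::{Result, Tensor, D};

    // Illustrative scaled-dot-product attention under the assumed shapes.
    fn naive_attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
        // (batch, seqlen, heads, dim) -> (batch, heads, seqlen, dim)
        let q = q.transpose(1, 2)?.contiguous()?;
        let k = k.transpose(1, 2)?.contiguous()?;
        let v = v.transpose(1, 2)?.contiguous()?;
        // Scores scaled by 1/sqrt(headdim): (batch, heads, seqlen, seqlen)
        let kt = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
        let scores = (q.matmul(&kt)? * softmax_scale as f64)?;
        let weights = candle_nn::ops::softmax_last_dim(&scores)?;
        // Weighted sum over values, back to (batch, seqlen, heads, dim);
        // the caller then flattens the last two dims into the hidden size.
        weights.matmul(&v)?.transpose(1, 2)
    }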