Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)

* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Czxck001
2024-11-01 10:10:40 -07:00
committed by GitHub
parent 7ac0de15a9
commit 530ab96036
3 changed files with 79 additions and 10 deletions

View File

@ -130,7 +130,14 @@ impl MMDiT {
})
}
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
pub fn forward(
&self,
x: &Tensor,
t: &Tensor,
y: &Tensor,
context: &Tensor,
skip_layers: Option<&[usize]>,
) -> Result<Tensor> {
// Following the convention of the ComfyUI implementation.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
//
@ -150,7 +157,7 @@ impl MMDiT {
let c = (c + y)?;
let context = self.context_embedder.forward(context)?;
let x = self.core.forward(&context, &x, &c)?;
let x = self.core.forward(&context, &x, &c, skip_layers)?;
let x = self.unpatchifier.unpatchify(&x, h, w)?;
x.narrow(2, 0, h)?.narrow(3, 0, w)
}
@ -211,9 +218,20 @@ impl MMDiTCore {
})
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
pub fn forward(
&self,
context: &Tensor,
x: &Tensor,
c: &Tensor,
skip_layers: Option<&[usize]>,
) -> Result<Tensor> {
let (mut context, mut x) = (context.clone(), x.clone());
for joint_block in &self.joint_blocks {
for (i, joint_block) in self.joint_blocks.iter().enumerate() {
if let Some(skip_layers) = &skip_layers {
if skip_layers.contains(&i) {
continue;
}
}
(context, x) = joint_block.forward(&context, &x, c)?;
}
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;