mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* support skip layer guidance (slg) for stable diffusion 3.5 medium * Tweak the comments formatting. * Proper error message. * Cosmetic tweaks. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -130,7 +130,14 @@ impl MMDiT {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
t: &Tensor,
|
||||
y: &Tensor,
|
||||
context: &Tensor,
|
||||
skip_layers: Option<&[usize]>,
|
||||
) -> Result<Tensor> {
|
||||
// Following the convention of the ComfyUI implementation.
|
||||
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
|
||||
//
|
||||
@ -150,7 +157,7 @@ impl MMDiT {
|
||||
let c = (c + y)?;
|
||||
let context = self.context_embedder.forward(context)?;
|
||||
|
||||
let x = self.core.forward(&context, &x, &c)?;
|
||||
let x = self.core.forward(&context, &x, &c, skip_layers)?;
|
||||
let x = self.unpatchifier.unpatchify(&x, h, w)?;
|
||||
x.narrow(2, 0, h)?.narrow(3, 0, w)
|
||||
}
|
||||
@ -211,9 +218,20 @@ impl MMDiTCore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(
|
||||
&self,
|
||||
context: &Tensor,
|
||||
x: &Tensor,
|
||||
c: &Tensor,
|
||||
skip_layers: Option<&[usize]>,
|
||||
) -> Result<Tensor> {
|
||||
let (mut context, mut x) = (context.clone(), x.clone());
|
||||
for joint_block in &self.joint_blocks {
|
||||
for (i, joint_block) in self.joint_blocks.iter().enumerate() {
|
||||
if let Some(skip_layers) = &skip_layers {
|
||||
if skip_layers.contains(&i) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
(context, x) = joint_block.forward(&context, &x, c)?;
|
||||
}
|
||||
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
|
||||
|
Reference in New Issue
Block a user