Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)

* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Czxck001
2024-11-01 10:10:40 -07:00
committed by GitHub
parent 7ac0de15a9
commit 530ab96036
3 changed files with 79 additions and 10 deletions

View File

@ -75,14 +75,19 @@ struct Args {
#[arg(long)]
num_inference_steps: Option<usize>,
// CFG scale.
/// CFG scale.
#[arg(long)]
cfg_scale: Option<f64>,
// Time shift factor (alpha).
/// Time shift factor (alpha).
#[arg(long, default_value_t = 3.0)]
time_shift: f64,
/// Use Skip Layer Guidance (SLG) for the sampling.
/// Currently only supports Stable Diffusion 3.5 Medium.
#[arg(long)]
use_slg: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
@ -105,6 +110,7 @@ fn main() -> Result<()> {
time_shift,
seed,
which,
use_slg,
} = Args::parse();
let _guard = if tracing {
@ -211,6 +217,22 @@ fn main() -> Result<()> {
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let slg_config = if use_slg {
match which {
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
scale: 2.5,
start: 0.01,
end: 0.2,
layers: vec![7, 8, 9],
}),
_ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
}
} else {
None
};
let start_time = std::time::Instant::now();
let x = {
let mmdit = MMDiT::new(
@ -227,6 +249,7 @@ fn main() -> Result<()> {
time_shift,
height,
width,
slg_config,
)?
};
let dt = start_time.elapsed().as_secs_f32();

View File

@ -1,8 +1,15 @@
use anyhow::{Ok, Result};
use candle::{DType, Tensor};
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::flux;
use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
use candle_transformers::models::mmdit::model::MMDiT;
/// Configuration for Skip Layer Guidance (SLG) during Euler sampling.
///
/// When enabled, an extra forward pass is run with the listed joint blocks
/// skipped, and the difference between the normal and skip-layer predictions
/// is added (scaled) to the CFG guidance.
pub struct SkipLayerGuidanceConfig {
    // Weight applied to the (full - skip-layer) noise-prediction difference.
    pub scale: f64,
    // Fraction of `num_inference_steps` after which SLG starts being applied.
    pub start: f64,
    // Fraction of `num_inference_steps` after which SLG stops being applied.
    pub end: f64,
    // Indices of the MMDiT joint blocks to skip in the SLG forward pass.
    pub layers: Vec<usize>,
}
#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
@ -14,6 +21,7 @@ pub fn euler_sample(
time_shift: f64,
height: usize,
width: usize,
slg_config: Option<SkipLayerGuidanceConfig>,
) -> Result<Tensor> {
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
let sigmas = (0..=num_inference_steps)
@ -22,7 +30,7 @@ pub fn euler_sample(
.map(|x| time_snr_shift(time_shift, x))
.collect::<Vec<f64>>();
for window in sigmas.windows(2) {
for (step, window) in sigmas.windows(2).enumerate() {
let (s_curr, s_prev) = match window {
[a, b] => (a, b),
_ => continue,
@ -34,8 +42,28 @@ pub fn euler_sample(
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
None,
)?;
x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
if let Some(slg_config) = slg_config.as_ref() {
if (num_inference_steps as f64) * slg_config.start < (step as f64)
&& (step as f64) < (num_inference_steps as f64) * slg_config.end
{
let slg_noise_pred = mmdit.forward(
&x,
&Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
&y.i(..1)?,
&context.i(..1)?,
Some(&slg_config.layers),
)?;
guidance = (guidance
+ (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
}
}
x = (x + (guidance * (*s_prev - *s_curr))?)?;
}
Ok(x)
}

View File

@ -130,7 +130,14 @@ impl MMDiT {
})
}
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
pub fn forward(
&self,
x: &Tensor,
t: &Tensor,
y: &Tensor,
context: &Tensor,
skip_layers: Option<&[usize]>,
) -> Result<Tensor> {
// Following the convention of the ComfyUI implementation.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
//
@ -150,7 +157,7 @@ impl MMDiT {
let c = (c + y)?;
let context = self.context_embedder.forward(context)?;
let x = self.core.forward(&context, &x, &c)?;
let x = self.core.forward(&context, &x, &c, skip_layers)?;
let x = self.unpatchifier.unpatchify(&x, h, w)?;
x.narrow(2, 0, h)?.narrow(3, 0, w)
}
@ -211,9 +218,20 @@ impl MMDiTCore {
})
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
pub fn forward(
&self,
context: &Tensor,
x: &Tensor,
c: &Tensor,
skip_layers: Option<&[usize]>,
) -> Result<Tensor> {
let (mut context, mut x) = (context.clone(), x.clone());
for joint_block in &self.joint_blocks {
for (i, joint_block) in self.joint_blocks.iter().enumerate() {
if let Some(skip_layers) = &skip_layers {
if skip_layers.contains(&i) {
continue;
}
}
(context, x) = joint_block.forward(&context, &x, c)?;
}
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;