Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)

* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Czxck001
2024-11-01 10:10:40 -07:00
committed by GitHub
parent 7ac0de15a9
commit 530ab96036
3 changed files with 79 additions and 10 deletions

View File

@ -75,14 +75,19 @@ struct Args {
#[arg(long)]
num_inference_steps: Option<usize>,
// CFG scale.
/// CFG scale.
#[arg(long)]
cfg_scale: Option<f64>,
// Time shift factor (alpha).
/// Time shift factor (alpha).
#[arg(long, default_value_t = 3.0)]
time_shift: f64,
/// Use Skip Layer Guidance (SLG) for the sampling.
/// Currently only supports Stable Diffusion 3.5 Medium.
#[arg(long)]
use_slg: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
@ -105,6 +110,7 @@ fn main() -> Result<()> {
time_shift,
seed,
which,
use_slg,
} = Args::parse();
let _guard = if tracing {
@ -211,6 +217,22 @@ fn main() -> Result<()> {
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let slg_config = if use_slg {
match which {
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
scale: 2.5,
start: 0.01,
end: 0.2,
layers: vec![7, 8, 9],
}),
_ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
}
} else {
None
};
let start_time = std::time::Instant::now();
let x = {
let mmdit = MMDiT::new(
@ -227,6 +249,7 @@ fn main() -> Result<()> {
time_shift,
height,
width,
slg_config,
)?
};
let dt = start_time.elapsed().as_secs_f32();