mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* support skip layer guidance (slg) for stable diffusion 3.5 medium * Tweak the comments formatting. * Proper error message. * Cosmetic tweaks. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -75,14 +75,19 @@ struct Args {
|
||||
#[arg(long)]
|
||||
num_inference_steps: Option<usize>,
|
||||
|
||||
// CFG scale.
|
||||
/// CFG scale.
|
||||
#[arg(long)]
|
||||
cfg_scale: Option<f64>,
|
||||
|
||||
// Time shift factor (alpha).
|
||||
/// Time shift factor (alpha).
|
||||
#[arg(long, default_value_t = 3.0)]
|
||||
time_shift: f64,
|
||||
|
||||
/// Use Skip Layer Guidance (SLG) for the sampling.
|
||||
/// Currently only supports Stable Diffusion 3.5 Medium.
|
||||
#[arg(long)]
|
||||
use_slg: bool,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long)]
|
||||
seed: Option<u64>,
|
||||
@ -105,6 +110,7 @@ fn main() -> Result<()> {
|
||||
time_shift,
|
||||
seed,
|
||||
which,
|
||||
use_slg,
|
||||
} = Args::parse();
|
||||
|
||||
let _guard = if tracing {
|
||||
@ -211,6 +217,22 @@ fn main() -> Result<()> {
|
||||
if let Some(seed) = seed {
|
||||
device.set_seed(seed)?;
|
||||
}
|
||||
|
||||
let slg_config = if use_slg {
|
||||
match which {
|
||||
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
|
||||
Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
|
||||
scale: 2.5,
|
||||
start: 0.01,
|
||||
end: 0.2,
|
||||
layers: vec![7, 8, 9],
|
||||
}),
|
||||
_ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let x = {
|
||||
let mmdit = MMDiT::new(
|
||||
@ -227,6 +249,7 @@ fn main() -> Result<()> {
|
||||
time_shift,
|
||||
height,
|
||||
width,
|
||||
slg_config,
|
||||
)?
|
||||
};
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
|
Reference in New Issue
Block a user