mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* Support skip layer guidance (SLG) for Stable Diffusion 3.5 Medium.
* Tweak the comments formatting.
* Proper error message.
* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -75,14 +75,19 @@ struct Args {
|
||||
#[arg(long)]
|
||||
num_inference_steps: Option<usize>,
|
||||
|
||||
// CFG scale.
|
||||
/// CFG scale.
|
||||
#[arg(long)]
|
||||
cfg_scale: Option<f64>,
|
||||
|
||||
// Time shift factor (alpha).
|
||||
/// Time shift factor (alpha).
|
||||
#[arg(long, default_value_t = 3.0)]
|
||||
time_shift: f64,
|
||||
|
||||
/// Use Skip Layer Guidance (SLG) for the sampling.
|
||||
/// Currently only supports Stable Diffusion 3.5 Medium.
|
||||
#[arg(long)]
|
||||
use_slg: bool,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long)]
|
||||
seed: Option<u64>,
|
||||
@ -105,6 +110,7 @@ fn main() -> Result<()> {
|
||||
time_shift,
|
||||
seed,
|
||||
which,
|
||||
use_slg,
|
||||
} = Args::parse();
|
||||
|
||||
let _guard = if tracing {
|
||||
@ -211,6 +217,22 @@ fn main() -> Result<()> {
|
||||
if let Some(seed) = seed {
|
||||
device.set_seed(seed)?;
|
||||
}
|
||||
|
||||
let slg_config = if use_slg {
|
||||
match which {
|
||||
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
|
||||
Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
|
||||
scale: 2.5,
|
||||
start: 0.01,
|
||||
end: 0.2,
|
||||
layers: vec![7, 8, 9],
|
||||
}),
|
||||
_ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let x = {
|
||||
let mmdit = MMDiT::new(
|
||||
@ -227,6 +249,7 @@ fn main() -> Result<()> {
|
||||
time_shift,
|
||||
height,
|
||||
width,
|
||||
slg_config,
|
||||
)?
|
||||
};
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
|
@ -1,8 +1,15 @@
|
||||
use anyhow::{Ok, Result};
|
||||
use candle::{DType, Tensor};
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
|
||||
use candle_transformers::models::flux;
|
||||
use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
|
||||
use candle_transformers::models::mmdit::model::MMDiT;
|
||||
|
||||
/// Settings for Skip Layer Guidance (SLG) during sampling.
///
/// When a config is supplied to the sampler, an extra forward pass is run with
/// the listed layers skipped, and the difference between the regular prediction
/// and that skip-layer prediction is added to the guidance term.
#[derive(Debug, Clone, PartialEq)]
pub struct SkipLayerGuidanceConfig {
    /// Multiplier applied to the difference between the regular noise
    /// prediction and the skip-layer noise prediction before it is added to
    /// the guidance.
    pub scale: f64,
    /// Lower bound of the SLG window, as a fraction of the total number of
    /// inference steps. SLG is only applied on steps strictly greater than
    /// `num_inference_steps * start`.
    pub start: f64,
    /// Upper bound of the SLG window, as a fraction of the total number of
    /// inference steps. SLG is only applied on steps strictly less than
    /// `num_inference_steps * end`.
    pub end: f64,
    /// Zero-based indices of the joint blocks to skip in the extra forward
    /// pass.
    pub layers: Vec<usize>,
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn euler_sample(
|
||||
@ -14,6 +21,7 @@ pub fn euler_sample(
|
||||
time_shift: f64,
|
||||
height: usize,
|
||||
width: usize,
|
||||
slg_config: Option<SkipLayerGuidanceConfig>,
|
||||
) -> Result<Tensor> {
|
||||
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
|
||||
let sigmas = (0..=num_inference_steps)
|
||||
@ -22,7 +30,7 @@ pub fn euler_sample(
|
||||
.map(|x| time_snr_shift(time_shift, x))
|
||||
.collect::<Vec<f64>>();
|
||||
|
||||
for window in sigmas.windows(2) {
|
||||
for (step, window) in sigmas.windows(2).enumerate() {
|
||||
let (s_curr, s_prev) = match window {
|
||||
[a, b] => (a, b),
|
||||
_ => continue,
|
||||
@ -34,8 +42,28 @@ pub fn euler_sample(
|
||||
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
|
||||
y,
|
||||
context,
|
||||
None,
|
||||
)?;
|
||||
x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
|
||||
|
||||
let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
|
||||
|
||||
if let Some(slg_config) = slg_config.as_ref() {
|
||||
if (num_inference_steps as f64) * slg_config.start < (step as f64)
|
||||
&& (step as f64) < (num_inference_steps as f64) * slg_config.end
|
||||
{
|
||||
let slg_noise_pred = mmdit.forward(
|
||||
&x,
|
||||
&Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
|
||||
&y.i(..1)?,
|
||||
&context.i(..1)?,
|
||||
Some(&slg_config.layers),
|
||||
)?;
|
||||
guidance = (guidance
|
||||
+ (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
|
||||
}
|
||||
}
|
||||
|
||||
x = (x + (guidance * (*s_prev - *s_curr))?)?;
|
||||
}
|
||||
Ok(x)
|
||||
}
|
||||
|
@ -130,7 +130,14 @@ impl MMDiT {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
t: &Tensor,
|
||||
y: &Tensor,
|
||||
context: &Tensor,
|
||||
skip_layers: Option<&[usize]>,
|
||||
) -> Result<Tensor> {
|
||||
// Following the convention of the ComfyUI implementation.
|
||||
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
|
||||
//
|
||||
@ -150,7 +157,7 @@ impl MMDiT {
|
||||
let c = (c + y)?;
|
||||
let context = self.context_embedder.forward(context)?;
|
||||
|
||||
let x = self.core.forward(&context, &x, &c)?;
|
||||
let x = self.core.forward(&context, &x, &c, skip_layers)?;
|
||||
let x = self.unpatchifier.unpatchify(&x, h, w)?;
|
||||
x.narrow(2, 0, h)?.narrow(3, 0, w)
|
||||
}
|
||||
@ -211,9 +218,20 @@ impl MMDiTCore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(
|
||||
&self,
|
||||
context: &Tensor,
|
||||
x: &Tensor,
|
||||
c: &Tensor,
|
||||
skip_layers: Option<&[usize]>,
|
||||
) -> Result<Tensor> {
|
||||
let (mut context, mut x) = (context.clone(), x.clone());
|
||||
for joint_block in &self.joint_blocks {
|
||||
for (i, joint_block) in self.joint_blocks.iter().enumerate() {
|
||||
if let Some(skip_layers) = &skip_layers {
|
||||
if skip_layers.contains(&i) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
(context, x) = joint_block.forward(&context, &x, c)?;
|
||||
}
|
||||
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
|
||||
|
Reference in New Issue
Block a user