mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Merge remote-tracking branch 'origin/main' into faster-layer-norm
This commit is contained in:
@ -75,14 +75,19 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
num_inference_steps: Option<usize>,
|
num_inference_steps: Option<usize>,
|
||||||
|
|
||||||
// CFG scale.
|
/// CFG scale.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cfg_scale: Option<f64>,
|
cfg_scale: Option<f64>,
|
||||||
|
|
||||||
// Time shift factor (alpha).
|
/// Time shift factor (alpha).
|
||||||
#[arg(long, default_value_t = 3.0)]
|
#[arg(long, default_value_t = 3.0)]
|
||||||
time_shift: f64,
|
time_shift: f64,
|
||||||
|
|
||||||
|
/// Use Skip Layer Guidance (SLG) for the sampling.
|
||||||
|
/// Currently only supports Stable Diffusion 3.5 Medium.
|
||||||
|
#[arg(long)]
|
||||||
|
use_slg: bool,
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
/// The seed to use when generating random samples.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
seed: Option<u64>,
|
seed: Option<u64>,
|
||||||
@ -105,6 +110,7 @@ fn main() -> Result<()> {
|
|||||||
time_shift,
|
time_shift,
|
||||||
seed,
|
seed,
|
||||||
which,
|
which,
|
||||||
|
use_slg,
|
||||||
} = Args::parse();
|
} = Args::parse();
|
||||||
|
|
||||||
let _guard = if tracing {
|
let _guard = if tracing {
|
||||||
@ -211,6 +217,22 @@ fn main() -> Result<()> {
|
|||||||
if let Some(seed) = seed {
|
if let Some(seed) = seed {
|
||||||
device.set_seed(seed)?;
|
device.set_seed(seed)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let slg_config = if use_slg {
|
||||||
|
match which {
|
||||||
|
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
|
||||||
|
Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
|
||||||
|
scale: 2.5,
|
||||||
|
start: 0.01,
|
||||||
|
end: 0.2,
|
||||||
|
layers: vec![7, 8, 9],
|
||||||
|
}),
|
||||||
|
_ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
let x = {
|
let x = {
|
||||||
let mmdit = MMDiT::new(
|
let mmdit = MMDiT::new(
|
||||||
@ -227,6 +249,7 @@ fn main() -> Result<()> {
|
|||||||
time_shift,
|
time_shift,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
|
slg_config,
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
let dt = start_time.elapsed().as_secs_f32();
|
let dt = start_time.elapsed().as_secs_f32();
|
||||||
|
@ -1,8 +1,15 @@
|
|||||||
use anyhow::{Ok, Result};
|
use anyhow::{Ok, Result};
|
||||||
use candle::{DType, Tensor};
|
use candle::{DType, IndexOp, Tensor};
|
||||||
|
|
||||||
use candle_transformers::models::flux;
|
use candle_transformers::models::flux;
|
||||||
use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
|
use candle_transformers::models::mmdit::model::MMDiT;
|
||||||
|
|
||||||
|
pub struct SkipLayerGuidanceConfig {
|
||||||
|
pub scale: f64,
|
||||||
|
pub start: f64,
|
||||||
|
pub end: f64,
|
||||||
|
pub layers: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn euler_sample(
|
pub fn euler_sample(
|
||||||
@ -14,6 +21,7 @@ pub fn euler_sample(
|
|||||||
time_shift: f64,
|
time_shift: f64,
|
||||||
height: usize,
|
height: usize,
|
||||||
width: usize,
|
width: usize,
|
||||||
|
slg_config: Option<SkipLayerGuidanceConfig>,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
|
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
|
||||||
let sigmas = (0..=num_inference_steps)
|
let sigmas = (0..=num_inference_steps)
|
||||||
@ -22,7 +30,7 @@ pub fn euler_sample(
|
|||||||
.map(|x| time_snr_shift(time_shift, x))
|
.map(|x| time_snr_shift(time_shift, x))
|
||||||
.collect::<Vec<f64>>();
|
.collect::<Vec<f64>>();
|
||||||
|
|
||||||
for window in sigmas.windows(2) {
|
for (step, window) in sigmas.windows(2).enumerate() {
|
||||||
let (s_curr, s_prev) = match window {
|
let (s_curr, s_prev) = match window {
|
||||||
[a, b] => (a, b),
|
[a, b] => (a, b),
|
||||||
_ => continue,
|
_ => continue,
|
||||||
@ -34,8 +42,28 @@ pub fn euler_sample(
|
|||||||
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
|
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
|
||||||
y,
|
y,
|
||||||
context,
|
context,
|
||||||
|
None,
|
||||||
)?;
|
)?;
|
||||||
x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
|
|
||||||
|
let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
|
||||||
|
|
||||||
|
if let Some(slg_config) = slg_config.as_ref() {
|
||||||
|
if (num_inference_steps as f64) * slg_config.start < (step as f64)
|
||||||
|
&& (step as f64) < (num_inference_steps as f64) * slg_config.end
|
||||||
|
{
|
||||||
|
let slg_noise_pred = mmdit.forward(
|
||||||
|
&x,
|
||||||
|
&Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
|
||||||
|
&y.i(..1)?,
|
||||||
|
&context.i(..1)?,
|
||||||
|
Some(&slg_config.layers),
|
||||||
|
)?;
|
||||||
|
guidance = (guidance
|
||||||
|
+ (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
x = (x + (guidance * (*s_prev - *s_curr))?)?;
|
||||||
}
|
}
|
||||||
Ok(x)
|
Ok(x)
|
||||||
}
|
}
|
||||||
|
@ -130,7 +130,14 @@ impl MMDiT {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
|
pub fn forward(
|
||||||
|
&self,
|
||||||
|
x: &Tensor,
|
||||||
|
t: &Tensor,
|
||||||
|
y: &Tensor,
|
||||||
|
context: &Tensor,
|
||||||
|
skip_layers: Option<&[usize]>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
// Following the convention of the ComfyUI implementation.
|
// Following the convention of the ComfyUI implementation.
|
||||||
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
|
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
|
||||||
//
|
//
|
||||||
@ -150,7 +157,7 @@ impl MMDiT {
|
|||||||
let c = (c + y)?;
|
let c = (c + y)?;
|
||||||
let context = self.context_embedder.forward(context)?;
|
let context = self.context_embedder.forward(context)?;
|
||||||
|
|
||||||
let x = self.core.forward(&context, &x, &c)?;
|
let x = self.core.forward(&context, &x, &c, skip_layers)?;
|
||||||
let x = self.unpatchifier.unpatchify(&x, h, w)?;
|
let x = self.unpatchifier.unpatchify(&x, h, w)?;
|
||||||
x.narrow(2, 0, h)?.narrow(3, 0, w)
|
x.narrow(2, 0, h)?.narrow(3, 0, w)
|
||||||
}
|
}
|
||||||
@ -211,9 +218,20 @@ impl MMDiTCore {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
pub fn forward(
|
||||||
|
&self,
|
||||||
|
context: &Tensor,
|
||||||
|
x: &Tensor,
|
||||||
|
c: &Tensor,
|
||||||
|
skip_layers: Option<&[usize]>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let (mut context, mut x) = (context.clone(), x.clone());
|
let (mut context, mut x) = (context.clone(), x.clone());
|
||||||
for joint_block in &self.joint_blocks {
|
for (i, joint_block) in self.joint_blocks.iter().enumerate() {
|
||||||
|
if let Some(skip_layers) = &skip_layers {
|
||||||
|
if skip_layers.contains(&i) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
(context, x) = joint_block.forward(&context, &x, c)?;
|
(context, x) = joint_block.forward(&context, &x, c)?;
|
||||||
}
|
}
|
||||||
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
|
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
|
||||||
|
Reference in New Issue
Block a user