Stable Diffusion Turbo Support (#1395)

* Add support for SD Turbo

* Set Leading as default in euler_ancestral discrete

* Use the appropriate default values for n_steps and guidance_scale.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Edwin Cheng
2023-12-03 15:37:10 +08:00
committed by GitHub
parent dd40edfe73
commit 37bf1ed012
6 changed files with 259 additions and 67 deletions

View File

@ -7,7 +7,9 @@
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing};
use super::schedulers::{
betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
@ -48,6 +50,12 @@ impl Default for DDIMSchedulerConfig {
}
}
impl SchedulerConfig for DDIMSchedulerConfig {
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
}
}
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
@ -63,7 +71,7 @@ impl DDIMScheduler {
/// Creates a new DDIM scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> = match config.timestep_spacing {
TimestepSpacing::Leading => (0..(inference_steps))
@ -115,19 +123,11 @@ impl DDIMScheduler {
config,
})
}
}
pub fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
Ok(sample)
}
impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@ -186,7 +186,17 @@ impl DDIMScheduler {
}
}
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
Ok(sample)
}
fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@ -197,7 +207,7 @@ impl DDIMScheduler {
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
}
pub fn init_noise_sigma(&self) -> f64 {
fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}

View File

@ -8,7 +8,10 @@
///
/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
use super::{
schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing},
schedulers::{
betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
TimestepSpacing,
},
utils::interp,
};
use candle::{bail, Error, Result, Tensor};
@ -43,11 +46,20 @@ impl Default for EulerAncestralDiscreteSchedulerConfig {
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
timestep_spacing: TimestepSpacing::Trailing,
timestep_spacing: TimestepSpacing::Leading,
}
}
}
impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
Ok(Box::new(EulerAncestralDiscreteScheduler::new(
inference_steps,
*self,
)?))
}
}
/// The EulerAncestral Discrete scheduler.
#[derive(Debug, Clone)]
pub struct EulerAncestralDiscreteScheduler {
@ -138,8 +150,10 @@ impl EulerAncestralDiscreteScheduler {
config,
})
}
}
pub fn timesteps(&self) -> &[usize] {
impl Scheduler for EulerAncestralDiscreteScheduler {
fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
@ -147,7 +161,7 @@ impl EulerAncestralDiscreteScheduler {
/// depending on the current timestep.
///
/// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
pub fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
Some(i) => i,
None => bail!("timestep out of this schedulers bounds: {timestep}"),
@ -162,7 +176,7 @@ impl EulerAncestralDiscreteScheduler {
}
/// Performs a backward step during inference.
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()
@ -197,7 +211,7 @@ impl EulerAncestralDiscreteScheduler {
prev_sample + noise * sigma_up
}
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()
@ -212,7 +226,7 @@ impl EulerAncestralDiscreteScheduler {
original + (noise * *sigma)?
}
pub fn init_noise_sigma(&self) -> f64 {
fn init_noise_sigma(&self) -> f64 {
match self.config.timestep_spacing {
TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),

View File

@ -11,9 +11,13 @@ pub mod unet_2d_blocks;
pub mod utils;
pub mod vae;
use std::sync::Arc;
use candle::{DType, Device, Result};
use candle_nn as nn;
use self::schedulers::{Scheduler, SchedulerConfig};
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
@ -22,7 +26,7 @@ pub struct StableDiffusionConfig {
pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
scheduler: Arc<dyn SchedulerConfig>,
}
impl StableDiffusionConfig {
@ -76,13 +80,18 @@ impl StableDiffusionConfig {
512
};
Self {
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type: schedulers::PredictionType::Epsilon,
..Default::default()
});
StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
scheduler: Default::default(),
scheduler,
unet,
}
}
@ -125,10 +134,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
};
});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -144,7 +153,7 @@ impl StableDiffusionConfig {
768
};
Self {
StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
@ -206,10 +215,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
};
});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -225,6 +234,76 @@ impl StableDiffusionConfig {
1024
};
StableDiffusionConfig {
width,
height,
clip: clip::Config::sdxl(),
clip2: Some(clip::Config::sdxl2()),
autoencoder,
scheduler,
unet,
}
}
fn sdxl_turbo_(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = Arc::new(
euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
prediction_type,
timestep_spacing: schedulers::TimestepSpacing::Trailing,
..Default::default()
},
);
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
512
};
Self {
width,
height,
@ -250,6 +329,20 @@ impl StableDiffusionConfig {
)
}
pub fn sdxl_turbo(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
Self::sdxl_turbo_(
sliced_attention_size,
height,
width,
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
schedulers::PredictionType::Epsilon,
)
}
pub fn ssd1b(
sliced_attention_size: Option<usize>,
height: Option<usize>,
@ -286,9 +379,9 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
..Default::default()
};
});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -348,8 +441,8 @@ impl StableDiffusionConfig {
Ok(unet)
}
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
self.scheduler.build(n_steps)
}
}

View File

@ -3,9 +3,25 @@
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.
use candle::{Result, Tensor};
pub trait SchedulerConfig: std::fmt::Debug {
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
}
/// This trait represents a scheduler for the diffusion process.
pub trait Scheduler {
fn timesteps(&self) -> &[usize];
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
fn init_noise_sigma(&self) -> f64;
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}
/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]