Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)

Stable Diffusion Turbo Support (#1395)

* Add support for SD Turbo
* Set Leading as default in euler_ancestral discrete
* Use the appropriate default values for n_steps and guidance_scale.

Co-authored-by: Laurent <laurent.mazare@gmail.com>
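The reworked API is easiest to see from the consumer side. Below is a minimal usage sketch (the `demo` function and the single-step count are illustrative assumptions, not code from this commit): `build_scheduler` now hands back a boxed `Scheduler` trait object instead of a concrete `DDIMScheduler`, so the SDXL Turbo config can slot in its Euler-ancestral scheduler without any change on the caller's side.

    // Sketch only; `demo` and the step count are assumptions.
    use candle_transformers::models::stable_diffusion::{schedulers::Scheduler, StableDiffusionConfig};

    fn demo() -> candle::Result<()> {
        // None for sliced attention size / height / width picks the 512x512 defaults.
        let config = StableDiffusionConfig::sdxl_turbo(None, None, None);
        // Returns Box<dyn Scheduler>; the concrete scheduler (Euler-ancestral here)
        // is chosen by the Arc<dyn SchedulerConfig> stored inside the config.
        let scheduler = config.build_scheduler(1)?; // turbo models run with very few steps
        println!("timesteps: {:?}", scheduler.timesteps());
        Ok(())
    }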
candle-transformers/src/models/stable_diffusion/ddim.rs:

@@ -7,7 +7,9 @@
 //!
 //! Denoising Diffusion Implicit Models, J. Song et al, 2020.
 //! https://arxiv.org/abs/2010.02502
-use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing};
+use super::schedulers::{
+    betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
+};
 use candle::{Result, Tensor};
 
 /// The configuration for the DDIM scheduler.
@@ -48,6 +50,12 @@ impl Default for DDIMSchedulerConfig {
     }
 }
 
+impl SchedulerConfig for DDIMSchedulerConfig {
+    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+        Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
+    }
+}
+
 /// The DDIM scheduler.
 #[derive(Debug, Clone)]
 pub struct DDIMScheduler {
@@ -63,7 +71,7 @@ impl DDIMScheduler {
     /// Creates a new DDIM scheduler given the number of steps to be
     /// used for inference as well as the number of steps that was used
     /// during training.
-    pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
+    fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
         let step_ratio = config.train_timesteps / inference_steps;
         let timesteps: Vec<usize> = match config.timestep_spacing {
             TimestepSpacing::Leading => (0..(inference_steps))
@@ -115,19 +123,11 @@ impl DDIMScheduler {
             config,
         })
     }
+}
 
-    pub fn timesteps(&self) -> &[usize] {
-        self.timesteps.as_slice()
-    }
-
-    /// Ensures interchangeability with schedulers that need to scale the denoising model input
-    /// depending on the current timestep.
-    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
-        Ok(sample)
-    }
-
+impl Scheduler for DDIMScheduler {
     /// Performs a backward step during inference.
-    pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
         let timestep = if timestep >= self.alphas_cumprod.len() {
             timestep - 1
         } else {
@@ -186,7 +186,17 @@ impl DDIMScheduler {
         }
     }
 
-    pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+    /// Ensures interchangeability with schedulers that need to scale the denoising model input
+    /// depending on the current timestep.
+    fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
+        Ok(sample)
+    }
+
+    fn timesteps(&self) -> &[usize] {
+        self.timesteps.as_slice()
+    }
+
+    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
         let timestep = if timestep >= self.alphas_cumprod.len() {
             timestep - 1
         } else {
@@ -197,7 +207,7 @@ impl DDIMScheduler {
         (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
     }
 
-    pub fn init_noise_sigma(&self) -> f64 {
+    fn init_noise_sigma(&self) -> f64 {
         self.init_noise_sigma
     }
 }
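The ddim.rs change is mostly mechanical: the inherent methods move into `impl Scheduler for DDIMScheduler`, `DDIMScheduler::new` loses its `pub`, and construction now goes through the config's `build` method. A small sketch of the resulting call pattern (assuming the crate's public module paths; not code from the commit):

    use candle_transformers::models::stable_diffusion::{
        ddim::DDIMSchedulerConfig,
        schedulers::{Scheduler, SchedulerConfig},
    };

    fn build_ddim() -> candle::Result<()> {
        let cfg = DDIMSchedulerConfig::default();
        // DDIMScheduler::new is private now; build is the public entry point.
        let scheduler = cfg.build(30)?; // 30 inference steps
        assert_eq!(scheduler.timesteps().len(), 30);
        Ok(())
    }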
candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs:

@@ -8,7 +8,10 @@
 ///
 /// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
 use super::{
-    schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing},
+    schedulers::{
+        betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
+        TimestepSpacing,
+    },
     utils::interp,
 };
 use candle::{bail, Error, Result, Tensor};
@@ -43,11 +46,20 @@ impl Default for EulerAncestralDiscreteSchedulerConfig {
             steps_offset: 1,
             prediction_type: PredictionType::Epsilon,
             train_timesteps: 1000,
-            timestep_spacing: TimestepSpacing::Trailing,
+            timestep_spacing: TimestepSpacing::Leading,
         }
     }
 }
 
+impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
+    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+        Ok(Box::new(EulerAncestralDiscreteScheduler::new(
+            inference_steps,
+            *self,
+        )?))
+    }
+}
+
 /// The EulerAncestral Discrete scheduler.
 #[derive(Debug, Clone)]
 pub struct EulerAncestralDiscreteScheduler {
@@ -138,8 +150,10 @@ impl EulerAncestralDiscreteScheduler {
             config,
         })
     }
+}
 
-    pub fn timesteps(&self) -> &[usize] {
+impl Scheduler for EulerAncestralDiscreteScheduler {
+    fn timesteps(&self) -> &[usize] {
         self.timesteps.as_slice()
     }
 
@@ -147,7 +161,7 @@ impl EulerAncestralDiscreteScheduler {
     /// depending on the current timestep.
     ///
    /// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
-    pub fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
+    fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
         let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
             Some(i) => i,
             None => bail!("timestep out of this schedulers bounds: {timestep}"),
@@ -162,7 +176,7 @@ impl EulerAncestralDiscreteScheduler {
     }
 
     /// Performs a backward step during inference.
-    pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
         let step_index = self
             .timesteps
             .iter()
@@ -197,7 +211,7 @@ impl EulerAncestralDiscreteScheduler {
         prev_sample + noise * sigma_up
     }
 
-    pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
         let step_index = self
             .timesteps
             .iter()
@@ -212,7 +226,7 @@ impl EulerAncestralDiscreteScheduler {
         original + (noise * *sigma)?
     }
 
-    pub fn init_noise_sigma(&self) -> f64 {
+    fn init_noise_sigma(&self) -> f64 {
         match self.config.timestep_spacing {
             TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
             TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
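The `init_noise_sigma` match above is the subtle part of this file: with `Leading` spacing the initial latent noise must be scaled by `sqrt(sigma_max^2 + 1)` rather than by `sigma_max` itself. A quick worked example (the sigma value is an assumed, typical SD-style sigma_max, not taken from the diff):

    fn main() {
        let sigma_max: f64 = 14.6146; // assumed value, typical of SD sigma schedules
        let trailing = sigma_max; // Trailing / Linspace: used as-is
        let leading = (sigma_max.powi(2) + 1.0).sqrt(); // Leading: lifted under the root
        println!("trailing: {trailing:.4}, leading: {leading:.4}"); // 14.6146 vs ~14.6488
    }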
candle-transformers/src/models/stable_diffusion/mod.rs:

@@ -11,9 +11,13 @@ pub mod unet_2d_blocks;
 pub mod utils;
 pub mod vae;
 
+use std::sync::Arc;
+
 use candle::{DType, Device, Result};
 use candle_nn as nn;
 
+use self::schedulers::{Scheduler, SchedulerConfig};
+
 #[derive(Clone, Debug)]
 pub struct StableDiffusionConfig {
     pub width: usize,
@@ -22,7 +26,7 @@ pub struct StableDiffusionConfig {
     pub clip2: Option<clip::Config>,
     autoencoder: vae::AutoEncoderKLConfig,
     unet: unet_2d::UNet2DConditionModelConfig,
-    scheduler: ddim::DDIMSchedulerConfig,
+    scheduler: Arc<dyn SchedulerConfig>,
 }
 
 impl StableDiffusionConfig {
@@ -76,13 +80,18 @@ impl StableDiffusionConfig {
             512
         };
 
-        Self {
+        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
+            prediction_type: schedulers::PredictionType::Epsilon,
+            ..Default::default()
+        });
+
+        StableDiffusionConfig {
             width,
             height,
             clip: clip::Config::v1_5(),
             clip2: None,
             autoencoder,
-            scheduler: Default::default(),
+            scheduler,
             unet,
         }
     }
@@ -125,10 +134,10 @@ impl StableDiffusionConfig {
             latent_channels: 4,
             norm_num_groups: 32,
         };
-        let scheduler = ddim::DDIMSchedulerConfig {
+        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             prediction_type,
             ..Default::default()
-        };
+        });
 
         let height = if let Some(height) = height {
             assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -144,7 +153,7 @@ impl StableDiffusionConfig {
             768
         };
 
-        Self {
+        StableDiffusionConfig {
             width,
             height,
             clip: clip::Config::v2_1(),
@@ -206,10 +215,10 @@ impl StableDiffusionConfig {
             latent_channels: 4,
             norm_num_groups: 32,
         };
-        let scheduler = ddim::DDIMSchedulerConfig {
+        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             prediction_type,
             ..Default::default()
-        };
+        });
 
         let height = if let Some(height) = height {
             assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -225,6 +234,76 @@ impl StableDiffusionConfig {
             1024
         };
 
+        StableDiffusionConfig {
+            width,
+            height,
+            clip: clip::Config::sdxl(),
+            clip2: Some(clip::Config::sdxl2()),
+            autoencoder,
+            scheduler,
+            unet,
+        }
+    }
+
+    fn sdxl_turbo_(
+        sliced_attention_size: Option<usize>,
+        height: Option<usize>,
+        width: Option<usize>,
+        prediction_type: schedulers::PredictionType,
+    ) -> Self {
+        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+            out_channels,
+            use_cross_attn,
+            attention_head_dim,
+        };
+        // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
+        let unet = unet_2d::UNet2DConditionModelConfig {
+            blocks: vec![
+                bc(320, None, 5),
+                bc(640, Some(2), 10),
+                bc(1280, Some(10), 20),
+            ],
+            center_input_sample: false,
+            cross_attention_dim: 2048,
+            downsample_padding: 1,
+            flip_sin_to_cos: true,
+            freq_shift: 0.,
+            layers_per_block: 2,
+            mid_block_scale_factor: 1.,
+            norm_eps: 1e-5,
+            norm_num_groups: 32,
+            sliced_attention_size,
+            use_linear_projection: true,
+        };
+        // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
+        let autoencoder = vae::AutoEncoderKLConfig {
+            block_out_channels: vec![128, 256, 512, 512],
+            layers_per_block: 2,
+            latent_channels: 4,
+            norm_num_groups: 32,
+        };
+        let scheduler = Arc::new(
+            euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
+                prediction_type,
+                timestep_spacing: schedulers::TimestepSpacing::Trailing,
+                ..Default::default()
+            },
+        );
+
+        let height = if let Some(height) = height {
+            assert_eq!(height % 8, 0, "height has to be divisible by 8");
+            height
+        } else {
+            512
+        };
+
+        let width = if let Some(width) = width {
+            assert_eq!(width % 8, 0, "width has to be divisible by 8");
+            width
+        } else {
+            512
+        };
+
         Self {
             width,
             height,
@@ -250,6 +329,20 @@ impl StableDiffusionConfig {
         )
     }
 
+    pub fn sdxl_turbo(
+        sliced_attention_size: Option<usize>,
+        height: Option<usize>,
+        width: Option<usize>,
+    ) -> Self {
+        Self::sdxl_turbo_(
+            sliced_attention_size,
+            height,
+            width,
+            // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
+            schedulers::PredictionType::Epsilon,
+        )
+    }
+
     pub fn ssd1b(
         sliced_attention_size: Option<usize>,
         height: Option<usize>,
@@ -286,9 +379,9 @@ impl StableDiffusionConfig {
             latent_channels: 4,
             norm_num_groups: 32,
         };
-        let scheduler = ddim::DDIMSchedulerConfig {
+        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             ..Default::default()
-        };
+        });
 
         let height = if let Some(height) = height {
             assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -348,8 +441,8 @@ impl StableDiffusionConfig {
         Ok(unet)
     }
 
-    pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
-        ddim::DDIMScheduler::new(n_steps, self.scheduler)
+    pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
+        self.scheduler.build(n_steps)
     }
 }
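Because `build_scheduler` now returns `Box<dyn Scheduler>`, a sampling loop can be written once against the trait and reused for DDIM, Euler-ancestral, or anything added later. A sketch of such a loop (the `unet` closure and all variable names are placeholders, not code from this commit):

    use candle::{Result, Tensor};
    use candle_transformers::models::stable_diffusion::schedulers::Scheduler;

    // Scheduler-agnostic denoising loop; `unet` stands in for the real
    // UNet forward pass and is an assumption of this example.
    fn denoise(
        scheduler: &dyn Scheduler,
        unet: impl Fn(&Tensor, usize) -> Result<Tensor>,
        mut latents: Tensor,
    ) -> Result<Tensor> {
        for &t in scheduler.timesteps() {
            let input = scheduler.scale_model_input(latents.clone(), t)?;
            let noise_pred = unet(&input, t)?;
            latents = scheduler.step(&noise_pred, t, &latents)?;
        }
        Ok(latents)
    }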
candle-transformers/src/models/stable_diffusion/schedulers.rs:

@@ -3,9 +3,25 @@
 //!
 //! Noise schedulers can be used to set the trade-off between
 //! inference speed and quality.
 
 use candle::{Result, Tensor};
 
+pub trait SchedulerConfig: std::fmt::Debug {
+    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
+}
+
+/// This trait represents a scheduler for the diffusion process.
+pub trait Scheduler {
+    fn timesteps(&self) -> &[usize];
+
+    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
+
+    fn init_noise_sigma(&self) -> f64;
+
+    fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
+
+    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
+}
+
 /// This represents how beta ranges from its minimum value to the maximum
 /// during training.
 #[derive(Debug, Clone, Copy)]
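To make the contract concrete, here is a minimal hypothetical implementor of the two traits, written as if it lived next to the definitions above. It does no real diffusion work; it only shows the shape that `DDIMScheduler` and `EulerAncestralDiscreteScheduler` have to fill in:

    use candle::{Result, Tensor};

    #[derive(Debug, Clone, Copy)]
    struct IdentitySchedulerConfig;

    #[derive(Debug, Clone)]
    struct IdentityScheduler {
        timesteps: Vec<usize>,
    }

    impl SchedulerConfig for IdentitySchedulerConfig {
        fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
            // Highest timestep first, as diffusion samplers iterate backwards.
            let timesteps = (0..inference_steps).rev().collect();
            Ok(Box::new(IdentityScheduler { timesteps }))
        }
    }

    impl Scheduler for IdentityScheduler {
        fn timesteps(&self) -> &[usize] {
            self.timesteps.as_slice()
        }

        fn add_noise(&self, original: &Tensor, noise: Tensor, _timestep: usize) -> Result<Tensor> {
            original + noise
        }

        fn init_noise_sigma(&self) -> f64 {
            1.0
        }

        fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
            Ok(sample)
        }

        fn step(&self, model_output: &Tensor, _timestep: usize, _sample: &Tensor) -> Result<Tensor> {
            // A real scheduler would compute the previous noisy sample here.
            Ok(model_output.clone())
        }
    }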