Add the ddim scheduler. (#330)

This commit is contained in:
Laurent Mazare
2023-08-06 21:44:00 +02:00
committed by GitHub
parent d34039e352
commit 1c062bf06b
5 changed files with 445 additions and 0 deletions

View File

@ -0,0 +1,181 @@
#![allow(dead_code)]
//! # Denoising Diffusion Implicit Models
//!
//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler
//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM
//! generative process is the reverse of a Markovian process, DDIM generalizes
//! this to non-Markovian guidance.
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
#[derive(Debug, Clone, Copy)]
pub struct DDIMSchedulerConfig {
/// The value of beta at the beginning of training.
pub beta_start: f64,
/// The value of beta at the end of training.
pub beta_end: f64,
/// How beta evolved during training.
pub beta_schedule: BetaSchedule,
/// The amount of noise to be added at each step.
pub eta: f64,
/// Adjust the indexes of the inference schedule by this value.
pub steps_offset: usize,
/// prediction type of the scheduler function, one of `epsilon` (predicting
/// the noise of the diffusion process), `sample` (directly predicting the noisy sample`)
/// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
pub prediction_type: PredictionType,
/// number of diffusion steps used to train the model
pub train_timesteps: usize,
}
impl Default for DDIMSchedulerConfig {
fn default() -> Self {
Self {
beta_start: 0.00085f64,
beta_end: 0.012f64,
beta_schedule: BetaSchedule::ScaledLinear,
eta: 0.,
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
}
}
}
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
timesteps: Vec<usize>,
alphas_cumprod: Vec<f64>,
step_ratio: usize,
init_noise_sigma: f64,
pub config: DDIMSchedulerConfig,
}
// clip_sample: False, set_alpha_to_one: False
impl DDIMScheduler {
/// Creates a new DDIM scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> = (0..(inference_steps))
.map(|s| s * step_ratio + config.steps_offset)
.rev()
.collect();
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => crate::utils::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps,
)?
.sqr()?,
BetaSchedule::Linear => {
crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
}
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
};
let betas = betas.to_vec1::<f64>()?;
let mut alphas_cumprod = Vec::with_capacity(betas.len());
for &beta in betas.iter() {
let alpha = 1.0 - beta;
alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
}
Ok(Self {
alphas_cumprod,
timesteps,
step_ratio,
init_noise_sigma: 1.,
config,
})
}
pub fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
sample
}
/// Performs a backward step during inference.
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
timestep
};
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
let prev_timestep = if timestep > self.step_ratio {
timestep - self.step_ratio
} else {
0
};
let alpha_prod_t = self.alphas_cumprod[timestep];
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
let beta_prod_t = 1. - alpha_prod_t;
let beta_prod_t_prev = 1. - alpha_prod_t_prev;
let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
PredictionType::Epsilon => {
let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
* (1. / alpha_prod_t.sqrt()))?;
(pred_original_sample, model_output.clone())
}
PredictionType::VPrediction => {
let pred_original_sample =
((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
let pred_epsilon =
((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
(pred_original_sample, pred_epsilon)
}
PredictionType::Sample => {
let pred_original_sample = model_output.clone();
let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
* (1. / beta_prod_t.sqrt()))?;
(pred_original_sample, pred_epsilon)
}
};
let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
let std_dev_t = self.config.eta * variance.sqrt();
let pred_sample_direction =
(pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
let prev_sample =
((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
if self.config.eta > 0. {
&prev_sample
+ Tensor::randn(
0f32,
std_dev_t as f32,
prev_sample.shape(),
prev_sample.device(),
)?
} else {
Ok(prev_sample)
}
}
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
timestep
};
let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
}
pub fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}

View File

@ -3,8 +3,11 @@ extern crate intel_mkl_src;
mod attention;
mod clip;
mod ddim;
mod embeddings;
mod resnet;
mod schedulers;
mod stable_diffusion;
mod unet_2d;
mod unet_2d_blocks;
mod utils;

View File

@ -0,0 +1,45 @@
#![allow(dead_code)]
//! # Diffusion pipelines and models
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.
use candle::{Result, Tensor};
/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]
pub enum BetaSchedule {
/// Linear interpolation.
Linear,
/// Linear interpolation of the square root of beta.
ScaledLinear,
/// Glide cosine schedule
SquaredcosCapV2,
}
#[derive(Debug, Clone, Copy)]
pub enum PredictionType {
Epsilon,
VPrediction,
Sample,
}
/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
/// `(1-beta)` over time from `t = [0,1]`.
///
/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)`
/// up to that part of the diffusion process.
pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
let alpha_bar = |time_step: usize| {
f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
};
let mut betas = Vec::with_capacity(num_diffusion_timesteps);
for i in 0..num_diffusion_timesteps {
let t1 = i / num_diffusion_timesteps;
let t2 = (i + 1) / num_diffusion_timesteps;
betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
}
let betas_len = betas.len();
Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
}

View File

@ -0,0 +1,212 @@
#![allow(dead_code)]
use crate::schedulers::PredictionType;
use crate::{clip, ddim, unet_2d, vae};
use candle::{DType, Device, Result};
use candle_nn as nn;
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
}
impl StableDiffusionConfig {
pub fn v1_5(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, true, 8),
bc(640, true, 8),
bc(1280, true, 8),
bc(1280, false, 8),
],
center_input_sample: false,
cross_attention_dim: 768,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: false,
};
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
height
} else {
512
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
512
};
Self {
width,
height,
clip: clip::Config::v1_5(),
autoencoder,
scheduler: Default::default(),
unet,
}
}
fn v2_1_(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, true, 5),
bc(640, true, 10),
bc(1280, true, 20),
bc(1280, false, 20),
],
center_input_sample: false,
cross_attention_dim: 1024,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
height
} else {
768
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
768
};
Self {
width,
height,
clip: clip::Config::v2_1(),
autoencoder,
scheduler,
unet,
}
}
pub fn v2_1(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
Self::v2_1_(
sliced_attention_size,
height,
width,
PredictionType::VPrediction,
)
}
pub fn v2_1_inpaint(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
// https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
// This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
// type being "epsilon" by default and not "v_prediction".
Self::v2_1_(
sliced_attention_size,
height,
width,
PredictionType::Epsilon,
)
}
pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
let weights = weights.deserialize()?;
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
Ok(autoencoder)
}
pub fn build_unet(
&self,
unet_weights: &str,
device: &Device,
in_channels: usize,
) -> Result<unet_2d::UNet2DConditionModel> {
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
let weights = weights.deserialize()?;
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
Ok(unet)
}
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
pub fn build_clip_transformer(
&self,
clip_weights: &str,
device: &Device,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
Ok(text_model)
}
}

View File

@ -15,3 +15,7 @@ pub fn pad(_: &Tensor) -> Result<Tensor> {
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
todo!()
}
pub fn linspace(_: f64, _: f64, _: usize) -> Result<Tensor> {
todo!()
}