Add the ddim scheduler. (#330)

candle-examples/examples/stable-diffusion/ddim.rs (new file, 181 lines)
@@ -0,0 +1,181 @@
#![allow(dead_code)]
//! # Denoising Diffusion Implicit Models
//!
//! Denoising Diffusion Implicit Models (DDIM) is a simple scheduler
//! similar to Denoising Diffusion Probabilistic Models (DDPM). Whereas the
//! DDPM generative process reverses a Markovian diffusion process, DDIM
//! generalizes this to non-Markovian guidance.
//!
//! Denoising Diffusion Implicit Models, J. Song et al., 2020.
//! https://arxiv.org/abs/2010.02502
use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor};

/// The configuration for the DDIM scheduler.
#[derive(Debug, Clone, Copy)]
pub struct DDIMSchedulerConfig {
    /// The value of beta at the beginning of training.
    pub beta_start: f64,
    /// The value of beta at the end of training.
    pub beta_end: f64,
    /// How beta evolved during training.
    pub beta_schedule: BetaSchedule,
    /// The amount of noise to be added at each step.
    pub eta: f64,
    /// Adjust the indices of the inference schedule by this value.
    pub steps_offset: usize,
    /// The prediction type of the scheduler function: one of `epsilon` (predicting
    /// the noise of the diffusion process), `sample` (directly predicting the noisy sample),
    /// or `v_prediction` (see section 2.4 of https://imagen.research.google/video/paper.pdf).
    pub prediction_type: PredictionType,
    /// The number of diffusion steps used to train the model.
    pub train_timesteps: usize,
}

impl Default for DDIMSchedulerConfig {
    fn default() -> Self {
        Self {
            beta_start: 0.00085f64,
            beta_end: 0.012f64,
            beta_schedule: BetaSchedule::ScaledLinear,
            eta: 0.,
            steps_offset: 1,
            prediction_type: PredictionType::Epsilon,
            train_timesteps: 1000,
        }
    }
}

/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
    timesteps: Vec<usize>,
    alphas_cumprod: Vec<f64>,
    step_ratio: usize,
    init_noise_sigma: f64,
    pub config: DDIMSchedulerConfig,
}

// clip_sample: False, set_alpha_to_one: False
impl DDIMScheduler {
    /// Creates a new DDIM scheduler given the number of steps to be
    /// used for inference as well as the number of steps that was used
    /// during training.
    pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
        let step_ratio = config.train_timesteps / inference_steps;
        let timesteps: Vec<usize> = (0..inference_steps)
            .map(|s| s * step_ratio + config.steps_offset)
            .rev()
            .collect();
        let betas = match config.beta_schedule {
            BetaSchedule::ScaledLinear => crate::utils::linspace(
                config.beta_start.sqrt(),
                config.beta_end.sqrt(),
                config.train_timesteps,
            )?
            .sqr()?,
            BetaSchedule::Linear => {
                crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
            }
            BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
        };
        let betas = betas.to_vec1::<f64>()?;
        let mut alphas_cumprod = Vec::with_capacity(betas.len());
        for &beta in betas.iter() {
            let alpha = 1.0 - beta;
            alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
        }
        Ok(Self {
            alphas_cumprod,
            timesteps,
            step_ratio,
            init_noise_sigma: 1.,
            config,
        })
    }

    pub fn timesteps(&self) -> &[usize] {
        self.timesteps.as_slice()
    }

    /// Ensures interchangeability with schedulers that need to scale the denoising model input
    /// depending on the current timestep.
    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
        sample
    }

    /// Performs a backward step during inference.
    pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
        let timestep = if timestep >= self.alphas_cumprod.len() {
            timestep - 1
        } else {
            timestep
        };
        // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
        let prev_timestep = if timestep > self.step_ratio {
            timestep - self.step_ratio
        } else {
            0
        };

        let alpha_prod_t = self.alphas_cumprod[timestep];
        let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
        let beta_prod_t = 1. - alpha_prod_t;
        let beta_prod_t_prev = 1. - alpha_prod_t_prev;

        let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
            PredictionType::Epsilon => {
                let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
                    * (1. / alpha_prod_t.sqrt()))?;
                (pred_original_sample, model_output.clone())
            }
            PredictionType::VPrediction => {
                let pred_original_sample =
                    ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
                let pred_epsilon =
                    ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
                (pred_original_sample, pred_epsilon)
            }
            PredictionType::Sample => {
                let pred_original_sample = model_output.clone();
                let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
                    * (1. / beta_prod_t.sqrt()))?;
                (pred_original_sample, pred_epsilon)
            }
        };

        let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
        let std_dev_t = self.config.eta * variance.sqrt();

        let pred_sample_direction =
            (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
        let prev_sample =
            ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
        if self.config.eta > 0. {
            &prev_sample
                + Tensor::randn(
                    0f32,
                    std_dev_t as f32,
                    prev_sample.shape(),
                    prev_sample.device(),
                )?
        } else {
            Ok(prev_sample)
        }
    }

    pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
        let timestep = if timestep >= self.alphas_cumprod.len() {
            timestep - 1
        } else {
            timestep
        };
        let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
        let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
        (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
    }

    pub fn init_noise_sigma(&self) -> f64 {
        self.init_noise_sigma
    }
}
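The update implemented by `step` is the DDIM rule x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * eps_hat + sigma_t * z, where `eta` scales sigma_t (eta = 0 gives the deterministic sampler). As a rough editorial sketch of how the scheduler is meant to be driven (not part of this commit; `unet` and `text_embeddings` stand in for the models built by `StableDiffusionConfig` below, and the UNet call signature is an assumption):

// Editorial sketch, not in the commit: a denoising loop using the scheduler.
// `unet` and `text_embeddings` are assumed to come from the builders in
// stable_diffusion.rs; batching and classifier-free guidance are elided.
let scheduler = ddim::DDIMScheduler::new(30, Default::default())?;
let mut latents = Tensor::randn(
    0f32,
    scheduler.init_noise_sigma() as f32,
    (1, 4, 64, 64),
    &candle::Device::Cpu,
)?;
for &timestep in scheduler.timesteps() {
    let input = scheduler.scale_model_input(latents.clone(), timestep);
    // Hypothetical UNet call predicting the noise (epsilon) at this timestep.
    let noise_pred = unet.forward(&input, timestep as f64, &text_embeddings)?;
    latents = scheduler.step(&noise_pred, timestep, &latents)?;
}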
candle-examples/examples/stable-diffusion/main.rs
@@ -3,8 +3,11 @@ extern crate intel_mkl_src;

mod attention;
mod clip;
mod ddim;
mod embeddings;
mod resnet;
mod schedulers;
mod stable_diffusion;
mod unet_2d;
mod unet_2d_blocks;
mod utils;

candle-examples/examples/stable-diffusion/schedulers.rs (new file, 45 lines)
@@ -0,0 +1,45 @@
#![allow(dead_code)]
//! # Diffusion pipelines and models
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.

use candle::{Result, Tensor};

/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]
pub enum BetaSchedule {
    /// Linear interpolation.
    Linear,
    /// Linear interpolation of the square root of beta.
    ScaledLinear,
    /// GLIDE cosine schedule.
    SquaredcosCapV2,
}

#[derive(Debug, Clone, Copy)]
pub enum PredictionType {
    Epsilon,
    VPrediction,
    Sample,
}

/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the
/// cumulative product of `(1-beta)` over time, for `t` in `[0,1]`.
///
/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative
/// product of `(1-beta)` up to that part of the diffusion process.
pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
    let alpha_bar = |time_step: f64| {
        f64::cos((time_step + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
    };
    let mut betas = Vec::with_capacity(num_diffusion_timesteps);
    for i in 0..num_diffusion_timesteps {
        // The timesteps must be fractional positions in [0, 1]: integer
        // division here would map every step to t = 0 and collapse the
        // schedule, so divide in f64.
        let t1 = i as f64 / num_diffusion_timesteps as f64;
        let t2 = (i + 1) as f64 / num_diffusion_timesteps as f64;
        betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
    }
    let betas_len = betas.len();
    Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
}
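With the fractional timesteps, the cosine schedule produces betas that start near zero and grow until capped by `max_beta`. A small editorial sanity check along these lines (not part of the commit) would be:

// Editorial check, not in the commit: betas from the cosine schedule should
// start tiny, never exceed max_beta, and be non-decreasing.
let betas = betas_for_alpha_bar(1000, 0.999)?.to_vec1::<f64>()?;
assert!(betas[0] < 1e-3);
assert!(betas.iter().all(|&b| b <= 0.999));
assert!(betas.windows(2).all(|w| w[0] <= w[1]));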
candle-examples/examples/stable-diffusion/stable_diffusion.rs (new file, 212 lines)
@@ -0,0 +1,212 @@
#![allow(dead_code)]
use crate::schedulers::PredictionType;
use crate::{clip, ddim, unet_2d, vae};
use candle::{DType, Device, Result};
use candle_nn as nn;

#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
    pub width: usize,
    pub height: usize,
    pub clip: clip::Config,
    autoencoder: vae::AutoEncoderKLConfig,
    unet: unet_2d::UNet2DConditionModelConfig,
    scheduler: ddim::DDIMSchedulerConfig,
}

impl StableDiffusionConfig {
    pub fn v1_5(
        sliced_attention_size: Option<usize>,
        height: Option<usize>,
        width: Option<usize>,
    ) -> Self {
        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
            out_channels,
            use_cross_attn,
            attention_head_dim,
        };
        // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
        let unet = unet_2d::UNet2DConditionModelConfig {
            blocks: vec![
                bc(320, true, 8),
                bc(640, true, 8),
                bc(1280, true, 8),
                bc(1280, false, 8),
            ],
            center_input_sample: false,
            cross_attention_dim: 768,
            downsample_padding: 1,
            flip_sin_to_cos: true,
            freq_shift: 0.,
            layers_per_block: 2,
            mid_block_scale_factor: 1.,
            norm_eps: 1e-5,
            norm_num_groups: 32,
            sliced_attention_size,
            use_linear_projection: false,
        };
        let autoencoder = vae::AutoEncoderKLConfig {
            block_out_channels: vec![128, 256, 512, 512],
            layers_per_block: 2,
            latent_channels: 4,
            norm_num_groups: 32,
        };
        let height = if let Some(height) = height {
            assert_eq!(height % 8, 0, "height has to be divisible by 8");
            height
        } else {
            512
        };

        let width = if let Some(width) = width {
            assert_eq!(width % 8, 0, "width has to be divisible by 8");
            width
        } else {
            512
        };

        Self {
            width,
            height,
            clip: clip::Config::v1_5(),
            autoencoder,
            scheduler: Default::default(),
            unet,
        }
    }

    fn v2_1_(
        sliced_attention_size: Option<usize>,
        height: Option<usize>,
        width: Option<usize>,
        prediction_type: PredictionType,
    ) -> Self {
        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
            out_channels,
            use_cross_attn,
            attention_head_dim,
        };
        // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
        let unet = unet_2d::UNet2DConditionModelConfig {
            blocks: vec![
                bc(320, true, 5),
                bc(640, true, 10),
                bc(1280, true, 20),
                bc(1280, false, 20),
            ],
            center_input_sample: false,
            cross_attention_dim: 1024,
            downsample_padding: 1,
            flip_sin_to_cos: true,
            freq_shift: 0.,
            layers_per_block: 2,
            mid_block_scale_factor: 1.,
            norm_eps: 1e-5,
            norm_num_groups: 32,
            sliced_attention_size,
            use_linear_projection: true,
        };
        // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
        let autoencoder = vae::AutoEncoderKLConfig {
            block_out_channels: vec![128, 256, 512, 512],
            layers_per_block: 2,
            latent_channels: 4,
            norm_num_groups: 32,
        };
        let scheduler = ddim::DDIMSchedulerConfig {
            prediction_type,
            ..Default::default()
        };

        let height = if let Some(height) = height {
            assert_eq!(height % 8, 0, "height has to be divisible by 8");
            height
        } else {
            768
        };

        let width = if let Some(width) = width {
            assert_eq!(width % 8, 0, "width has to be divisible by 8");
            width
        } else {
            768
        };

        Self {
            width,
            height,
            clip: clip::Config::v2_1(),
            autoencoder,
            scheduler,
            unet,
        }
    }

    pub fn v2_1(
        sliced_attention_size: Option<usize>,
        height: Option<usize>,
        width: Option<usize>,
    ) -> Self {
        // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
        Self::v2_1_(
            sliced_attention_size,
            height,
            width,
            PredictionType::VPrediction,
        )
    }

    pub fn v2_1_inpaint(
        sliced_attention_size: Option<usize>,
        height: Option<usize>,
        width: Option<usize>,
    ) -> Self {
        // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
        // This uses a PNDM scheduler rather than DDIM, but the biggest difference is that the
        // default prediction type is "epsilon" rather than "v_prediction".
        Self::v2_1_(
            sliced_attention_size,
            height,
            width,
            PredictionType::Epsilon,
        )
    }

    pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
        let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
        let weights = weights.deserialize()?;
        let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
        // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
        let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
        Ok(autoencoder)
    }

    pub fn build_unet(
        &self,
        unet_weights: &str,
        device: &Device,
        in_channels: usize,
    ) -> Result<unet_2d::UNet2DConditionModel> {
        let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
        let weights = weights.deserialize()?;
        let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
        let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
        Ok(unet)
    }

    pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
        ddim::DDIMScheduler::new(n_steps, self.scheduler)
    }

    pub fn build_clip_transformer(
        &self,
        clip_weights: &str,
        device: &Device,
    ) -> Result<clip::ClipTextTransformer> {
        let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
        let weights = weights.deserialize()?;
        let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
        let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
        Ok(text_model)
    }
}
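Taken together, the config owns everything needed to assemble the pipeline: it picks the UNet, VAE, CLIP, and scheduler settings for a given model version, and the `build_*` methods load the weights. A minimal editorial sketch of the intended call sequence (the `.safetensors` paths are placeholders for locally downloaded files):

// Editorial sketch, not in the commit: wiring up the v2.1 pipeline pieces.
let sd_config = stable_diffusion::StableDiffusionConfig::v2_1(None, None, None);
let device = Device::Cpu;
let _clip = sd_config.build_clip_transformer("clip.safetensors", &device)?;
let _vae = sd_config.build_vae("vae.safetensors", &device)?;
let _unet = sd_config.build_unet("unet.safetensors", &device, 4)?;
let _scheduler = sd_config.build_scheduler(30)?;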
candle-examples/examples/stable-diffusion/utils.rs
@@ -15,3 +15,7 @@ pub fn pad(_: &Tensor) -> Result<Tensor> {
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
    todo!()
}

pub fn linspace(_: f64, _: f64, _: usize) -> Result<Tensor> {
    todo!()
}
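`linspace` is stubbed out with `todo!()` in this commit even though the DDIM scheduler's linear beta schedules call it. One possible CPU-only implementation, offered purely as an assumption about what could fill the stub, not as the commit's eventual code:

use candle::{Device, Result, Tensor};

// Hypothetical implementation, not part of the commit: `steps` evenly spaced
// f64 values from `start` to `stop` inclusive, materialized on the CPU.
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
    let vs: Vec<f64> = match steps {
        0 => vec![],
        1 => vec![start],
        _ => (0..steps)
            .map(|i| start + (stop - start) * i as f64 / (steps - 1) as f64)
            .collect(),
    };
    Tensor::from_vec(vs, steps, &Device::Cpu)
}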