mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Add the DDPM scheduler. (#877)
* Add the DDPM scheduler.
* Minor tweaks.
@@ -15,6 +15,7 @@ use clap::Parser;
 use tokenizers::Tokenizer;
 
 const GUIDANCE_SCALE: f64 = 7.5;
+const RESOLUTION_MULTIPLE: f64 = 42.67;
 
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
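
(`RESOLUTION_MULTIPLE` is the pixel-to-latent scale factor: with the 1024x1024 defaults added below, ceil(1024 / 42.67) = 24, so the prior operates on 24x24 latents.)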
@@ -217,6 +218,8 @@ fn run(args: Args) -> Result<()> {
     };
 
     let device = candle_examples::device(cpu)?;
+    let height = height.unwrap_or(1024);
+    let width = width.unwrap_or(1024);
 
     let text_embeddings = encode_prompt(
         &prompt,
@@ -225,12 +228,12 @@ fn run(args: Args) -> Result<()> {
         clip_weights.clone(),
         stable_diffusion::clip::Config::wuerstchen(),
         &device,
-    );
+    )?;
     println!("{text_embeddings:?}");
 
     println!("Building the prior.");
     // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
-    let _prior = {
+    let prior = {
         let prior_weights = ModelFile::Prior.get(prior_weights)?;
         let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
         let weights = weights.deserialize()?;
@@ -238,7 +241,7 @@ fn run(args: Args) -> Result<()> {
         wuerstchen::prior::WPrior::new(
             /* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64,
             /* depth */ 32, /* nhead */ 24, vb,
-        )
+        )?
     };
 
     println!("Building the vqgan.");
@@ -264,8 +267,21 @@ fn run(args: Args) -> Result<()> {
         )?
     };
 
-    let _bsize = 1;
+    let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+    let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+    let b_size = 1;
+    for idx in 0..num_samples {
+        let latents = Tensor::randn(
+            0f32,
+            1f32,
+            (b_size, 4, latent_height, latent_width),
+            &device,
+        )?;
+        // TODO: latents denoising loop, use the scheduler values.
+        let ratio = Tensor::ones(1, DType::F32, &device)?;
+        let prior = prior.forward(&latents, &ratio, &text_embeddings)?;
+
     let latents = ((latents * 42.)? - 1.)?;
     /*
     let timesteps = scheduler.timesteps();
     let latents = Tensor::randn(
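
The TODO above is the hook for the scheduler this commit introduces. A minimal sketch of what that denoising loop could look like once wired up — the step count of 60, the use of the prior's output as an epsilon-noise prediction, and encoding the timestep into `ratio` as a fraction of the 1000 training steps are all assumptions, not part of this diff:

// Hypothetical wiring of the new scheduler into the loop above.
let scheduler = stable_diffusion::ddpm::DDPMScheduler::new(60, Default::default())?;
let mut latents = latents;
for &t in scheduler.timesteps() {
    // Assumption: the prior conditions on the noise level via `ratio`.
    let ratio = (Tensor::ones(1, DType::F32, &device)? * (t as f64 / 1000.))?;
    let noise_pred = prior.forward(&latents, &ratio, &text_embeddings)?;
    // Move one step back along the diffusion process.
    latents = scheduler.step(&noise_pred, t, &latents)?;
}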

candle-transformers/src/models/stable_diffusion/ddpm.rs (new file, 205 lines)
@@ -0,0 +1,205 @@
use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor};

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DDPMVarianceType {
    FixedSmall,
    FixedSmallLog,
    FixedLarge,
    FixedLargeLog,
    Learned,
}

impl Default for DDPMVarianceType {
    fn default() -> Self {
        Self::FixedSmall
    }
}

#[derive(Debug, Clone)]
pub struct DDPMSchedulerConfig {
    /// The value of beta at the beginning of training.
    pub beta_start: f64,
    /// The value of beta at the end of training.
    pub beta_end: f64,
    /// How beta evolved during training.
    pub beta_schedule: BetaSchedule,
    /// Option to clip the predicted sample between -1 and 1 for numerical stability.
    pub clip_sample: bool,
    /// Option to clip the variance used when adding noise to the denoised sample.
    pub variance_type: DDPMVarianceType,
    /// Prediction type of the scheduler function.
    pub prediction_type: PredictionType,
    /// Number of diffusion steps used to train the model.
    pub train_timesteps: usize,
}

impl Default for DDPMSchedulerConfig {
    fn default() -> Self {
        Self {
            beta_start: 0.00085,
            beta_end: 0.012,
            beta_schedule: BetaSchedule::ScaledLinear,
            clip_sample: false,
            variance_type: DDPMVarianceType::FixedSmall,
            prediction_type: PredictionType::Epsilon,
            train_timesteps: 1000,
        }
    }
}

pub struct DDPMScheduler {
    alphas_cumprod: Vec<f64>,
    init_noise_sigma: f64,
    timesteps: Vec<usize>,
    step_ratio: usize,
    pub config: DDPMSchedulerConfig,
}

impl DDPMScheduler {
    pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Result<Self> {
        let betas = match config.beta_schedule {
            BetaSchedule::ScaledLinear => super::utils::linspace(
                config.beta_start.sqrt(),
                config.beta_end.sqrt(),
                config.train_timesteps,
            )?
            .sqr()?,
            BetaSchedule::Linear => {
                super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
            }
            BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
        };

        let betas = betas.to_vec1::<f64>()?;
        let mut alphas_cumprod = Vec::with_capacity(betas.len());
        for &beta in betas.iter() {
            let alpha = 1.0 - beta;
            alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
        }

        // min(train_timesteps, inference_steps)
        // https://github.com/huggingface/diffusers/blob/8331da46837be40f96fbd24de6a6fb2da28acd11/src/diffusers/schedulers/scheduling_ddpm.py#L187
        let inference_steps = inference_steps.min(config.train_timesteps);
        // space the scheduler's timesteps evenly over the training range
        let step_ratio = config.train_timesteps / inference_steps;
        let timesteps: Vec<usize> = (0..inference_steps).map(|s| s * step_ratio).rev().collect();

        Ok(Self {
            alphas_cumprod,
            init_noise_sigma: 1.0,
            timesteps,
            step_ratio,
            config,
        })
    }

    fn get_variance(&self, timestep: usize) -> f64 {
        let prev_t = timestep as isize - self.step_ratio as isize;
        let alpha_prod_t = self.alphas_cumprod[timestep];
        let alpha_prod_t_prev = if prev_t >= 0 {
            self.alphas_cumprod[prev_t as usize]
        } else {
            1.0
        };
        let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev;

        // For t > 0, compute the predicted variance β̃_t (see formulas (6) and (7) from
        // https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get the previous sample,
        // x_{t-1} ~ N(pred_prev_sample, variance), i.e. add the variance to pred_sample.
        let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t;

        // retrieve the variance according to the configured variance type
        match self.config.variance_type {
            DDPMVarianceType::FixedSmall => variance.max(1e-20),
            // for rl-diffuser https://arxiv.org/abs/2205.09991
            DDPMVarianceType::FixedSmallLog => {
                let variance = variance.max(1e-20).ln();
                (variance * 0.5).exp()
            }
            DDPMVarianceType::FixedLarge => current_beta_t,
            DDPMVarianceType::FixedLargeLog => current_beta_t.ln(),
            DDPMVarianceType::Learned => variance,
        }
    }

    pub fn timesteps(&self) -> &[usize] {
        self.timesteps.as_slice()
    }

    /// Ensures interchangeability with schedulers that need to scale the denoising model input
    /// depending on the current timestep.
    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
        sample
    }

    pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
        let prev_t = timestep as isize - self.step_ratio as isize;

        // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L272
        // 1. compute alphas and betas
        let alpha_prod_t = self.alphas_cumprod[timestep];
        let alpha_prod_t_prev = if prev_t >= 0 {
            self.alphas_cumprod[prev_t as usize]
        } else {
            1.0
        };
        let beta_prod_t = 1. - alpha_prod_t;
        let beta_prod_t_prev = 1. - alpha_prod_t_prev;
        let current_alpha_t = alpha_prod_t / alpha_prod_t_prev;
        let current_beta_t = 1. - current_alpha_t;

        // 2. compute the predicted original sample from the predicted noise, also called
        // "predicted x_0" in formula (15)
        let mut pred_original_sample = match self.config.prediction_type {
            PredictionType::Epsilon => {
                ((sample - model_output * beta_prod_t.sqrt())? / alpha_prod_t.sqrt())?
            }
            PredictionType::Sample => model_output.clone(),
            PredictionType::VPrediction => {
                ((sample * alpha_prod_t.sqrt())? - model_output * beta_prod_t.sqrt())?
            }
        };

        // 3. clip the predicted x_0
        if self.config.clip_sample {
            pred_original_sample = pred_original_sample.clamp(-1f32, 1f32)?;
        }

        // 4. compute the coefficients for pred_original_sample x_0 and the current sample x_t
        // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t;
        let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t;

        // 5. compute the predicted previous sample µ_t
        // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        let pred_prev_sample = ((&pred_original_sample * pred_original_sample_coeff)?
            + sample * current_sample_coeff)?;

        // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305
        // 6. add noise
        let mut variance = model_output.zeros_like()?;
        if timestep > 0 {
            let variance_noise = model_output.randn_like(0., 1.)?;
            if self.config.variance_type == DDPMVarianceType::FixedSmallLog {
                variance = (variance_noise * self.get_variance(timestep))?;
            } else {
                variance = (variance_noise * self.get_variance(timestep).sqrt())?;
            }
        }
        &pred_prev_sample + variance
    }

    pub fn add_noise(
        &self,
        original_samples: &Tensor,
        noise: Tensor,
        timestep: usize,
    ) -> Result<Tensor> {
        (original_samples * self.alphas_cumprod[timestep].sqrt())?
            + noise * (1. - self.alphas_cumprod[timestep]).sqrt()
    }

    pub fn init_noise_sigma(&self) -> f64 {
        self.init_noise_sigma
    }
}
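
For reference, `get_variance` and `step` together compute the Gaussian posterior q(x_{t-1} | x_t, x_0) from the DDPM paper; in the notation of formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf:

\tilde{\mu}_t(x_t, x_0)
  = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0
  + \frac{\sqrt{\alpha_t}\,\left(1 - \bar{\alpha}_{t-1}\right)}{1 - \bar{\alpha}_t}\, x_t,
\qquad
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t

`pred_original_sample_coeff` and `current_sample_coeff` are the two coefficients of µ̃_t, and the `variance` computed in `get_variance` is β̃_t, floored at 1e-20 for numerical stability in the FixedSmall variants. One wrinkle: when running fewer inference steps than training steps, α_t and β_t are replaced by the effective per-step values `current_alpha_t` and `current_beta_t`, accumulated across `step_ratio` training steps.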

@@ -1,6 +1,7 @@
 pub mod attention;
 pub mod clip;
 pub mod ddim;
+pub mod ddpm;
 pub mod embeddings;
 pub mod resnet;
 pub mod schedulers;
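
With `pub mod ddpm;` exported, the scheduler can be driven on its own. A self-contained sketch of the public API — the zero `noise_pred` is a stand-in for a real denoising model, and the shapes and step count are arbitrary, none of it from the commit:

use candle::{Device, Result, Tensor};
use candle_transformers::models::stable_diffusion::ddpm::{DDPMScheduler, DDPMSchedulerConfig};

fn denoise_sketch() -> Result<Tensor> {
    let device = Device::Cpu;
    let scheduler = DDPMScheduler::new(30, DDPMSchedulerConfig::default())?;
    // Start from pure Gaussian noise, scaled by the scheduler's initial sigma (1.0 for DDPM).
    let mut x = (Tensor::randn(0f32, 1f32, (1, 4, 32, 32), &device)?
        * scheduler.init_noise_sigma())?;
    // Walk the spaced-out timesteps from most to least noisy.
    for &t in scheduler.timesteps() {
        let x_in = scheduler.scale_model_input(x.clone(), t); // a no-op for DDPM
        let noise_pred = x_in.zeros_like()?; // stand-in for a real model's epsilon prediction
        x = scheduler.step(&noise_pred, t, &x)?;
    }
    Ok(x)
}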