diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs index 32c7d158..7e4d2360 100644 --- a/candle-examples/examples/wuerstchen/main.rs +++ b/candle-examples/examples/wuerstchen/main.rs @@ -15,6 +15,7 @@ use clap::Parser; use tokenizers::Tokenizer; const GUIDANCE_SCALE: f64 = 7.5; +const RESOLUTION_MULTIPLE: f64 = 42.67; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -217,6 +218,8 @@ fn run(args: Args) -> Result<()> { }; let device = candle_examples::device(cpu)?; + let height = height.unwrap_or(1024); + let width = width.unwrap_or(1024); let text_embeddings = encode_prompt( &prompt, @@ -225,12 +228,12 @@ fn run(args: Args) -> Result<()> { clip_weights.clone(), stable_diffusion::clip::Config::wuerstchen(), &device, - ); + )?; println!("{text_embeddings:?}"); println!("Building the prior."); // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json - let _prior = { + let prior = { let prior_weights = ModelFile::Prior.get(prior_weights)?; let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? }; let weights = weights.deserialize()?; @@ -238,7 +241,7 @@ fn run(args: Args) -> Result<()> { wuerstchen::prior::WPrior::new( /* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb, - ) + )? }; println!("Building the vqgan."); @@ -264,8 +267,21 @@ fn run(args: Args) -> Result<()> { )? }; - let _bsize = 1; + let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize; + let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize; + let b_size = 1; for idx in 0..num_samples { + let latents = Tensor::randn( + 0f32, + 1f32, + (b_size, 4, latent_height, latent_width), + &device, + )?; + // TODO: latents denoising loop, use the scheduler values. + let ratio = Tensor::ones(1, DType::F32, &device)?; + let prior = prior.forward(&latents, &ratio, &text_embeddings)?; + + let latents = ((latents * 42.)? - 1.)?; /* let timesteps = scheduler.timesteps(); let latents = Tensor::randn( diff --git a/candle-transformers/src/models/stable_diffusion/ddpm.rs b/candle-transformers/src/models/stable_diffusion/ddpm.rs new file mode 100644 index 00000000..d393f39a --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/ddpm.rs @@ -0,0 +1,205 @@ +use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use candle::{Result, Tensor}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DDPMVarianceType { + FixedSmall, + FixedSmallLog, + FixedLarge, + FixedLargeLog, + Learned, +} + +impl Default for DDPMVarianceType { + fn default() -> Self { + Self::FixedSmall + } +} + +#[derive(Debug, Clone)] +pub struct DDPMSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// Option to predicted sample between -1 and 1 for numerical stability. + pub clip_sample: bool, + /// Option to clip the variance used when adding noise to the denoised sample. + pub variance_type: DDPMVarianceType, + /// prediction type of the scheduler function + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model. + pub train_timesteps: usize, +} + +impl Default for DDPMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + clip_sample: false, + variance_type: DDPMVarianceType::FixedSmall, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +pub struct DDPMScheduler { + alphas_cumprod: Vec, + init_noise_sigma: f64, + timesteps: Vec, + step_ratio: usize, + pub config: DDPMSchedulerConfig, +} + +impl DDPMScheduler { + pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Result { + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => super::utils::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps, + )? + .sqr()?, + BetaSchedule::Linear => { + super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? + } + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, + }; + + let betas = betas.to_vec1::()?; + let mut alphas_cumprod = Vec::with_capacity(betas.len()); + for &beta in betas.iter() { + let alpha = 1.0 - beta; + alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64)) + } + + // min(train_timesteps, inference_steps) + // https://github.com/huggingface/diffusers/blob/8331da46837be40f96fbd24de6a6fb2da28acd11/src/diffusers/schedulers/scheduling_ddpm.py#L187 + let inference_steps = inference_steps.min(config.train_timesteps); + // arange the number of the scheduler's timesteps + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec = (0..inference_steps).map(|s| s * step_ratio).rev().collect(); + + Ok(Self { + alphas_cumprod, + init_noise_sigma: 1.0, + timesteps, + step_ratio, + config, + }) + } + + fn get_variance(&self, timestep: usize) -> f64 { + let prev_t = timestep as isize - self.step_ratio as isize; + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_t >= 0 { + self.alphas_cumprod[prev_t as usize] + } else { + 1.0 + }; + let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev; + + // For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + // and sample from it to get previous sample + // x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t; + + // retrieve variance + match self.config.variance_type { + DDPMVarianceType::FixedSmall => variance.max(1e-20), + // for rl-diffuser https://arxiv.org/abs/2205.09991 + DDPMVarianceType::FixedSmallLog => { + let variance = variance.max(1e-20).ln(); + (variance * 0.5).exp() + } + DDPMVarianceType::FixedLarge => current_beta_t, + DDPMVarianceType::FixedLargeLog => current_beta_t.ln(), + DDPMVarianceType::Learned => variance, + } + } + + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + sample + } + + pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result { + let prev_t = timestep as isize - self.step_ratio as isize; + + // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L272 + // 1. compute alphas, betas + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_t >= 0 { + self.alphas_cumprod[prev_t as usize] + } else { + 1.0 + }; + let beta_prod_t = 1. - alpha_prod_t; + let beta_prod_t_prev = 1. - alpha_prod_t_prev; + let current_alpha_t = alpha_prod_t / alpha_prod_t_prev; + let current_beta_t = 1. - current_alpha_t; + + // 2. compute predicted original sample from predicted noise also called "predicted x_0" of formula (15) + let mut pred_original_sample = match self.config.prediction_type { + PredictionType::Epsilon => { + ((sample - model_output * beta_prod_t.sqrt())? / alpha_prod_t.sqrt())? + } + PredictionType::Sample => model_output.clone(), + PredictionType::VPrediction => { + ((sample * alpha_prod_t.sqrt())? - model_output * beta_prod_t.sqrt())? + } + }; + + // 3. clip predicted x_0 + if self.config.clip_sample { + pred_original_sample = pred_original_sample.clamp(-1f32, 1f32)?; + } + + // 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t; + let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t; + + // 5. Compute predicted previous sample µ_t + // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + let pred_prev_sample = ((&pred_original_sample * pred_original_sample_coeff)? + + sample * current_sample_coeff)?; + + // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305 + // 6. Add noise + let mut variance = model_output.zeros_like()?; + if timestep > 0 { + let variance_noise = model_output.randn_like(0., 1.)?; + if self.config.variance_type == DDPMVarianceType::FixedSmallLog { + variance = (variance_noise * self.get_variance(timestep))?; + } else { + variance = (variance_noise * self.get_variance(timestep).sqrt())?; + } + } + &pred_prev_sample + variance + } + + pub fn add_noise( + &self, + original_samples: &Tensor, + noise: Tensor, + timestep: usize, + ) -> Result { + (original_samples * self.alphas_cumprod[timestep].sqrt())? + + noise * (1. - self.alphas_cumprod[timestep]).sqrt() + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index d9721532..c6f1b904 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -1,6 +1,7 @@ pub mod attention; pub mod clip; pub mod ddim; +pub mod ddpm; pub mod embeddings; pub mod resnet; pub mod schedulers;