From 5082954c523c82a76ce91622ae9c3966dce0dc89 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 18 Sep 2023 14:50:14 +0100
Subject: [PATCH] Fix the W clip embeddings. (#887)

* Fix the W clip embeddings.

* Add the specialized ddpm scheduler.
---
 candle-examples/examples/wuerstchen/main.rs        |  6 +-
 .../src/models/wuerstchen/ddpm.rs                  | 97 +++++++++++++++++++
 .../src/models/wuerstchen/mod.rs                   |  1 +
 3 files changed, 101 insertions(+), 3 deletions(-)
 create mode 100644 candle-transformers/src/models/wuerstchen/ddpm.rs

diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index b3231360..12e4c10e 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -193,9 +193,9 @@ fn encode_prompt(
 
     println!("Building the clip transformer.");
     let text_model =
         stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
-    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
-    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
-    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
+    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
+    let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
     Ok(text_embeddings)
 }
diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs
new file mode 100644
index 00000000..f4f16bfb
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/ddpm.rs
@@ -0,0 +1,97 @@
+use candle::{Result, Tensor};
+
+#[derive(Debug, Clone)]
+pub struct DDPMWSchedulerConfig {
+    scaler: f64,
+    s: f64,
+}
+
+impl Default for DDPMWSchedulerConfig {
+    fn default() -> Self {
+        Self {
+            scaler: 1f64,
+            s: 0.008f64,
+        }
+    }
+}
+
+pub struct DDPMWScheduler {
+    init_alpha_cumprod: f64,
+    init_noise_sigma: f64,
+    timesteps: Vec<f64>,
+    pub config: DDPMWSchedulerConfig,
+}
+
+impl DDPMWScheduler {
+    pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> {
+        let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI)
+            .cos()
+            .powi(2);
+        let timesteps = (0..=inference_steps)
+            .map(|i| 1. - i as f64 / inference_steps as f64)
+            .collect::<Vec<_>>();
+        Ok(Self {
+            init_alpha_cumprod,
+            init_noise_sigma: 1.0,
+            timesteps,
+            config,
+        })
+    }
+
+    fn alpha_cumprod(&self, t: f64) -> f64 {
+        let scaler = self.config.scaler;
+        let s = self.config.s;
+        let t = if scaler > 1. {
+            1. - (1. - t).powf(scaler)
+        } else if scaler < 1. {
+            t.powf(scaler)
+        } else {
+            t
+        };
+        let alpha_cumprod =
+            ((t + s) / (1. + s) * std::f64::consts::PI * 0.5).cos().powi(2) / self.init_alpha_cumprod;
+        alpha_cumprod.clamp(0.0001, 0.9999)
+    }
+
+    fn previous_timestep(&self, ts: f64) -> f64 {
+        let index = self
+            .timesteps
+            .iter()
+            .enumerate()
+            .map(|(idx, v)| (idx, (v - ts).abs()))
+            .min_by(|x, y| x.1.total_cmp(&y.1))
+            .unwrap()
+            .0;
+        self.timesteps[index + 1]
+    }
+
+    /// Ensures interchangeability with schedulers that need to scale the denoising model input
+    /// depending on the current timestep.
+    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
+        sample
+    }
+
+    pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> {
+        let prev_t = self.previous_timestep(ts);
+
+        let alpha_cumprod = self.alpha_cumprod(ts);
+        let alpha_cumprod_prev = self.alpha_cumprod(prev_t);
+        let alpha = alpha_cumprod / alpha_cumprod_prev;
+
+        let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?;
+        let mu = (mu * (1. / alpha).sqrt())?;
+
+        let std_noise = mu.randn_like(0., 1.)?;
+        let std =
+            std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt();
+        if prev_t == 0. {
+            Ok(mu)
+        } else {
+            mu + std
+        }
+    }
+
+    pub fn init_noise_sigma(&self) -> f64 {
+        self.init_noise_sigma
+    }
+}
diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs
index 435bdac2..f499bc35 100644
--- a/candle-transformers/src/models/wuerstchen/mod.rs
+++ b/candle-transformers/src/models/wuerstchen/mod.rs
@@ -1,4 +1,5 @@
 pub mod common;
+pub mod ddpm;
 pub mod diffnext;
 pub mod paella_vq;
 pub mod prior;
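
For reference, the denoising loop this scheduler is designed to drive looks roughly as follows. This is a sketch, not part of the patch: `denoise`, the `prior` closure, the `latent_shape` argument, and the 60-step count are illustrative, and since the `timesteps` field is private (the patch adds no accessor), the sketch rebuilds the same `1 - i / n` grid that `DDPMWScheduler::new` constructs.

    use candle::{Device, Result, Tensor};
    use candle_transformers::models::wuerstchen::ddpm::DDPMWScheduler;

    // Reverse-diffusion driver for DDPMWScheduler. `prior` stands in for the
    // Wuerstchen prior network: any closure mapping (latents, timestep) to a
    // noise prediction of the same shape works here.
    fn denoise<F>(
        prior: F,
        latent_shape: (usize, usize, usize, usize),
        device: &Device,
    ) -> Result<Tensor>
    where
        F: Fn(&Tensor, f64) -> Result<Tensor>,
    {
        let inference_steps = 60;
        let scheduler = DDPMWScheduler::new(inference_steps, Default::default())?;
        // `timesteps` is private in the patch, so rebuild the identical
        // `1 - i / n` grid that `DDPMWScheduler::new` computes.
        let timesteps: Vec<f64> = (0..=inference_steps)
            .map(|i| 1. - i as f64 / inference_steps as f64)
            .collect();
        // Start from pure Gaussian noise scaled by the initial sigma (1.0).
        let noise = Tensor::randn(0f32, 1f32, latent_shape, device)?;
        let mut latents = (noise * scheduler.init_noise_sigma())?;
        // Walk the grid from t = 1 down to t = 1/n; each `step` moves the
        // sample to the next grid point, and the last call (prev_t == 0)
        // returns the posterior mean without injecting fresh noise.
        for &ts in timesteps.iter().take(timesteps.len() - 1) {
            let noise_pred = prior(&latents, ts)?;
            latents = scheduler.step(&noise_pred, ts, &latents)?;
        }
        Ok(latents)
    }

Stopping one entry before the end of the grid matters: `previous_timestep` returns the entry after the nearest match, so calling `step` with the final `t = 0` would index past the end of `timesteps`, and the `prev_t == 0.` branch already yields the noise-free mean on the last iteration above.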