From 06e46d7c3bb4ee4ca3ae0a64e9c2add95f5e0fb3 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 19 Sep 2023 14:13:05 +0100
Subject: [PATCH] Only use classifier free guidance for the prior. (#896)

* Only use classifier free guidance for the prior.

* Add another specific layer-norm structure.

* Tweaks.

* Fix the latent shape.

* Print the prior shape.

* More shape fixes.

* Remove some debugging continue.
---
 candle-examples/examples/wuerstchen/main.rs | 154 ++++++++++--------
 .../src/models/wuerstchen/common.rs         |  28 ++++
 .../src/models/wuerstchen/diffnext.rs       |   7 +-
 .../src/models/wuerstchen/paella_vq.rs      |  10 +-
 4 files changed, 125 insertions(+), 74 deletions(-)

diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index b92fe8fd..4e4bce0b 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -16,7 +16,9 @@ use tokenizers::Tokenizer;
 
 const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
 const RESOLUTION_MULTIPLE: f64 = 42.67;
+const LATENT_DIM_SCALE: f64 = 10.67;
 const PRIOR_CIN: usize = 16;
+const DECODER_CIN: usize = 4;
 
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
@@ -156,7 +158,7 @@ fn output_filename(
 
 fn encode_prompt(
     prompt: &str,
-    uncond_prompt: &str,
+    uncond_prompt: Option<&str>,
     tokenizer: std::path::PathBuf,
     clip_weights: std::path::PathBuf,
     clip_config: stable_diffusion::clip::Config,
@@ -179,24 +181,30 @@ fn encode_prompt(
     }
     let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
 
-    let mut uncond_tokens = tokenizer
-        .encode(uncond_prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    let uncond_tokens_len = uncond_tokens.len();
-    while uncond_tokens.len() < clip_config.max_position_embeddings {
-        uncond_tokens.push(pad_id)
-    }
-    let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
-
     println!("Building the clip transformer.");
     let text_model =
         stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
     let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
-    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
-    let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
-    Ok(text_embeddings)
+    match uncond_prompt {
+        None => Ok(text_embeddings),
+        Some(uncond_prompt) => {
+            let mut uncond_tokens = tokenizer
+                .encode(uncond_prompt, true)
+                .map_err(E::msg)?
+                .get_ids()
+                .to_vec();
+            let uncond_tokens_len = uncond_tokens.len();
+            while uncond_tokens.len() < clip_config.max_position_embeddings {
+                uncond_tokens.push(pad_id)
+            }
+            let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+
+            let uncond_embeddings =
+                text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
+            let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
+            Ok(text_embeddings)
+        }
+    }
 }
 
 fn run(args: Args) -> Result<()> {
@@ -239,40 +247,72 @@ fn run(args: Args) -> Result<()> {
         let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
         encode_prompt(
             &prompt,
-            &uncond_prompt,
+            Some(&uncond_prompt),
             tokenizer.clone(),
             weights,
             stable_diffusion::clip::Config::wuerstchen_prior(),
             &device,
         )?
     };
-    println!("{prior_text_embeddings}");
+    println!("generated prior text embeddings {prior_text_embeddings:?}");
 
     let text_embeddings = {
         let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
         let weights = ModelFile::Clip.get(clip_weights)?;
         encode_prompt(
             &prompt,
-            &uncond_prompt,
+            None,
             tokenizer.clone(),
             weights,
             stable_diffusion::clip::Config::wuerstchen(),
             &device,
         )?
     };
-    println!("{prior_text_embeddings}");
+    println!("generated text embeddings {text_embeddings:?}");
 
     println!("Building the prior.");
-    // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
-    let prior = {
-        let prior_weights = ModelFile::Prior.get(prior_weights)?;
-        let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
-        let weights = weights.deserialize()?;
-        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
-        wuerstchen::prior::WPrior::new(
-            /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
-            /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
-        )?
+    let b_size = 1;
+    let image_embeddings = {
+        // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
+        let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+        let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+        let mut latents = Tensor::randn(
+            0f32,
+            1f32,
+            (b_size, PRIOR_CIN, latent_height, latent_width),
+            &device,
+        )?;
+
+        let prior = {
+            let prior_weights = ModelFile::Prior.get(prior_weights)?;
+            let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
+            let weights = weights.deserialize()?;
+            let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+            wuerstchen::prior::WPrior::new(
+                /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
+                /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
+            )?
+        };
+        let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
+        let timesteps = prior_scheduler.timesteps();
+        println!("prior denoising");
+        for (index, &t) in timesteps.iter().enumerate() {
+            let start_time = std::time::Instant::now();
+            if index == timesteps.len() - 1 {
+                continue;
+            }
+            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+            let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
+            let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
+            let noise_pred = noise_pred.chunk(2, 0)?;
+            let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
+            let noise_pred = (noise_pred_uncond
+                + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
+            latents = prior_scheduler.step(&noise_pred, t, &latents)?;
+            let dt = start_time.elapsed().as_secs_f32();
+            println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
+        }
+        ((latents * 42.)? - 1.)?
     };
 
     println!("Building the vqgan.");
@@ -293,58 +333,40 @@ fn run(args: Args) -> Result<()> {
         let weights = weights.deserialize()?;
         let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
         wuerstchen::diffnext::WDiffNeXt::new(
-            /* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
-            /* clip_embd */ 1024, /* patch_size */ 2, vb,
+            /* c_in */ DECODER_CIN,
+            /* c_out */ DECODER_CIN,
+            /* c_r */ 64,
+            /* c_cond */ 1024,
+            /* clip_embd */ 1024,
+            /* patch_size */ 2,
+            vb,
         )?
     };
 
-    let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
-    let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
-    let b_size = 1;
     for idx in 0..num_samples {
+        // https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
+        let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
+        let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
+
         let mut latents = Tensor::randn(
             0f32,
             1f32,
-            (b_size, PRIOR_CIN, latent_height, latent_width),
+            (b_size, DECODER_CIN, latent_height, latent_width),
             &device,
         )?;
 
-        let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
-        let timesteps = prior_scheduler.timesteps();
-        println!("prior denoising");
+        println!("diffusion process with prior {image_embeddings:?}");
+        let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
+        let timesteps = scheduler.timesteps();
         for (index, &t) in timesteps.iter().enumerate() {
             let start_time = std::time::Instant::now();
             if index == timesteps.len() - 1 {
                 continue;
             }
-            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
-            let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
-            let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
-            let noise_pred = noise_pred.chunk(2, 0)?;
-            let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
-            let noise_pred = (noise_pred_uncond
-                + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
-            latents = prior_scheduler.step(&noise_pred, t, &latents)?;
-            let dt = start_time.elapsed().as_secs_f32();
-            println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
-        }
-        let effnet = ((latents * 42.)? - 1.)?;
-        let mut latents = Tensor::randn(
-            0f32,
-            1f32,
-            (b_size, PRIOR_CIN, latent_height, latent_width),
-            &device,
-        )?;
-
-        println!("diffusion process");
-        for (index, &t) in timesteps.iter().enumerate() {
-            let start_time = std::time::Instant::now();
-            if index == timesteps.len() - 1 {
-                continue;
-            }
-            let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
-            let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
-            latents = prior_scheduler.step(&noise_pred, t, &latents)?;
+            let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
+            let noise_pred =
+                decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
+            latents = scheduler.step(&noise_pred, t, &latents)?;
             let dt = start_time.elapsed().as_secs_f32();
             println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
         }
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 1eb0c2e7..3cac2a59 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -34,6 +34,34 @@ impl Module for WLayerNorm {
     }
 }
 
+#[derive(Debug)]
+pub struct LayerNormNoWeights {
+    eps: f64,
+}
+
+impl LayerNormNoWeights {
+    pub fn new(_size: usize) -> Result<Self> {
+        Ok(Self { eps: 1e-6 })
+    }
+}
+
+impl Module for LayerNormNoWeights {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let x_dtype = xs.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let hidden_size = xs.dim(D::Minus1)?;
+        let xs = xs.to_dtype(internal_dtype)?;
+        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        let xs = xs.broadcast_sub(&mean_x)?;
+        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)
+    }
+}
+
 #[derive(Debug)]
 pub struct TimestepBlock {
     mapper: candle_nn::Linear,
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 664251ed..60b799ae 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -1,4 +1,4 @@
-use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
+use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
 use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
@@ -75,7 +75,7 @@ struct UpBlock {
 pub struct WDiffNeXt {
     clip_mapper: candle_nn::Linear,
     effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
-    seq_norm: WLayerNorm,
+    seq_norm: LayerNormNoWeights,
    embedding_conv: candle_nn::Conv2d,
     embedding_ln: WLayerNorm,
     down_blocks: Vec<DownBlock>,
@@ -133,7 +133,7 @@ impl WDiffNeXt {
             };
             effnet_mappers.push(c)
         }
-        let seq_norm = WLayerNorm::new(c_cond)?;
+        let seq_norm = LayerNormNoWeights::new(c_cond)?;
         let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
         let embedding_conv = candle_nn::conv2d(
             c_in * patch_size * patch_size,
@@ -335,6 +335,7 @@ impl WDiffNeXt {
             level_outputs.push(xs.clone())
         }
         level_outputs.reverse();
+        let mut xs = level_outputs[0].clone();
 
         for (i, up_block) in self.up_blocks.iter().enumerate() {
             let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index faf2d2b4..8cf33505 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -1,12 +1,12 @@
-use super::common::WLayerNorm;
+use super::common::LayerNormNoWeights;
 use candle::{Module, Result, Tensor};
 use candle_nn::VarBuilder;
 
 #[derive(Debug)]
 pub struct MixingResidualBlock {
-    norm1: WLayerNorm,
+    norm1: LayerNormNoWeights,
     depthwise_conv: candle_nn::Conv2d,
-    norm2: WLayerNorm,
+    norm2: LayerNormNoWeights,
     channelwise_lin1: candle_nn::Linear,
     channelwise_lin2: candle_nn::Linear,
     gammas: Vec<Tensor>,
@@ -14,8 +14,8 @@ pub struct MixingResidualBlock {
 
 impl MixingResidualBlock {
     pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
-        let norm1 = WLayerNorm::new(inp)?;
-        let norm2 = WLayerNorm::new(inp)?;
+        let norm1 = LayerNormNoWeights::new(inp)?;
+        let norm2 = LayerNormNoWeights::new(inp)?;
         let cfg = candle_nn::Conv2dConfig {
             groups: inp,
             ..Default::default()
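
Note on the prior denoising loop in main.rs above: it is a standard classifier-free guidance
update. The conditional and unconditional branches are batched together (Tensor::cat of the
latents, a single prior.forward call, then chunk(2, 0) to split the prediction), and the two
halves are combined as

    noise_pred = noise_pred_uncond + PRIOR_GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)

with PRIOR_GUIDANCE_SCALE = 8.0. After this patch the decoder loop runs unguided: its text
embeddings are encoded with uncond_prompt set to None, its ratio tensor has a single element,
and decoder.forward is called on a batch of one.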
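
Note on LayerNormNoWeights in common.rs: unlike WLayerNorm it carries no learned scale or bias,
only a fixed eps of 1e-6, and its constructor takes no VarBuilder, so switching paella_vq.rs and
the seq_norm of WDiffNeXt over to it loads no additional tensors. Over the last dimension (in
f32 for f16/bf16 inputs) the forward pass computes

    y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps)

i.e. a plain layer normalization without affine parameters.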