From ec895453cdb0f27832e9ec28d219cefe77ffb875 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 19 Sep 2023 13:43:19 +0100 Subject: [PATCH] More shape fixes. --- candle-examples/examples/wuerstchen/main.rs | 21 ++++++++++++------- .../src/models/wuerstchen/diffnext.rs | 1 + 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs index 46403e47..265ebf06 100644 --- a/candle-examples/examples/wuerstchen/main.rs +++ b/candle-examples/examples/wuerstchen/main.rs @@ -16,6 +16,7 @@ use tokenizers::Tokenizer; const PRIOR_GUIDANCE_SCALE: f64 = 8.0; const RESOLUTION_MULTIPLE: f64 = 42.67; +const LATENT_DIM_SCALE: f64 = 10.67; const PRIOR_CIN: usize = 16; const DECODER_CIN: usize = 4; @@ -270,11 +271,11 @@ fn run(args: Args) -> Result<()> { println!("generated text embeddings {text_embeddings:?}"); println!("Building the prior."); - // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json - let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize; - let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize; let b_size = 1; - let effnet = { + let image_embeddings = { + // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json + let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize; + let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize; let mut latents = Tensor::randn( 0f32, 1f32, @@ -296,6 +297,7 @@ fn run(args: Args) -> Result<()> { let timesteps = prior_scheduler.timesteps(); println!("prior denoising"); for (index, &t) in timesteps.iter().enumerate() { + continue; let start_time = std::time::Instant::now(); if index == timesteps.len() - 1 { continue; @@ -343,6 +345,10 @@ fn run(args: Args) -> Result<()> { }; for idx in 0..num_samples { + // https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json + let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize; + let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize; + let mut latents = Tensor::randn( 0f32, 1f32, @@ -350,7 +356,7 @@ fn run(args: Args) -> Result<()> { &device, )?; - println!("diffusion process with prior {effnet:?}"); + println!("diffusion process with prior {image_embeddings:?}"); let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?; let timesteps = scheduler.timesteps(); for (index, &t) in timesteps.iter().enumerate() { @@ -358,8 +364,9 @@ fn run(args: Args) -> Result<()> { if index == timesteps.len() - 1 { continue; } - let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?; - let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?; + let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?; + let noise_pred = + decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?; latents = scheduler.step(&noise_pred, t, &latents)?; let dt = start_time.elapsed().as_secs_f32(); println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt); diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index afa83a16..60b799ae 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -335,6 +335,7 @@ impl WDiffNeXt { level_outputs.push(xs.clone()) } level_outputs.reverse(); + let mut xs = level_outputs[0].clone(); for (i, up_block) in self.up_blocks.iter().enumerate() { let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {