From 5d8e214dfe74a0bf686aca0b5114a16157c2a72a Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 19 Sep 2023 09:21:35 +0100 Subject: [PATCH] Fix the latent shape. --- candle-examples/examples/wuerstchen/main.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs index 1428faa2..ac0526af 100644 --- a/candle-examples/examples/wuerstchen/main.rs +++ b/candle-examples/examples/wuerstchen/main.rs @@ -17,6 +17,7 @@ use tokenizers::Tokenizer; const PRIOR_GUIDANCE_SCALE: f64 = 8.0; const RESOLUTION_MULTIPLE: f64 = 42.67; const PRIOR_CIN: usize = 16; +const DECODER_CIN: usize = 4; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -331,8 +332,13 @@ fn run(args: Args) -> Result<()> { let weights = weights.deserialize()?; let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); wuerstchen::diffnext::WDiffNeXt::new( - /* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024, - /* clip_embd */ 1024, /* patch_size */ 2, vb, + /* c_in */ DECODER_CIN, + /* c_out */ DECODER_CIN, + /* c_r */ 64, + /* c_cond */ 1024, + /* clip_embd */ 1024, + /* patch_size */ 2, + vb, )? }; @@ -340,7 +346,7 @@ fn run(args: Args) -> Result<()> { let mut latents = Tensor::randn( 0f32, 1f32, - (b_size, PRIOR_CIN, latent_height, latent_width), + (b_size, DECODER_CIN, latent_height, latent_width), &device, )?;