Fix the latent shape.

This commit is contained in:
laurent
2023-09-19 09:21:35 +01:00
parent 576bf7c21f
commit 5d8e214dfe

View File

@ -17,6 +17,7 @@ use tokenizers::Tokenizer;
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
const RESOLUTION_MULTIPLE: f64 = 42.67;
const PRIOR_CIN: usize = 16;
const DECODER_CIN: usize = 4;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
@ -331,8 +332,13 @@ fn run(args: Args) -> Result<()> {
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::diffnext::WDiffNeXt::new(
/* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
/* clip_embd */ 1024, /* patch_size */ 2, vb,
/* c_in */ DECODER_CIN,
/* c_out */ DECODER_CIN,
/* c_r */ 64,
/* c_cond */ 1024,
/* clip_embd */ 1024,
/* patch_size */ 2,
vb,
)?
};
@ -340,7 +346,7 @@ fn run(args: Args) -> Result<()> {
let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, PRIOR_CIN, latent_height, latent_width),
(b_size, DECODER_CIN, latent_height, latent_width),
&device,
)?;