Fix the latent shape.

This commit is contained in:
laurent
2023-09-19 09:21:35 +01:00
parent 576bf7c21f
commit 5d8e214dfe

View File

@@ -17,6 +17,7 @@ use tokenizers::Tokenizer;
 const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
 const RESOLUTION_MULTIPLE: f64 = 42.67;
 const PRIOR_CIN: usize = 16;
+const DECODER_CIN: usize = 4;
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
@@ -331,8 +332,13 @@ fn run(args: Args) -> Result<()> {
 let weights = weights.deserialize()?;
 let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
 wuerstchen::diffnext::WDiffNeXt::new(
-    /* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
-    /* clip_embd */ 1024, /* patch_size */ 2, vb,
+    /* c_in */ DECODER_CIN,
+    /* c_out */ DECODER_CIN,
+    /* c_r */ 64,
+    /* c_cond */ 1024,
+    /* clip_embd */ 1024,
+    /* patch_size */ 2,
+    vb,
 )?
 };
@@ -340,7 +346,7 @@ fn run(args: Args) -> Result<()> {
 let mut latents = Tensor::randn(
     0f32,
     1f32,
-    (b_size, PRIOR_CIN, latent_height, latent_width),
+    (b_size, DECODER_CIN, latent_height, latent_width),
     &device,
 )?;