mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Fix the latent shape.
This commit is contained in:
@ -17,6 +17,7 @@ use tokenizers::Tokenizer;
|
||||
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||
const PRIOR_CIN: usize = 16;
|
||||
const DECODER_CIN: usize = 4;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -331,8 +332,13 @@ fn run(args: Args) -> Result<()> {
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
wuerstchen::diffnext::WDiffNeXt::new(
|
||||
/* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
|
||||
/* clip_embd */ 1024, /* patch_size */ 2, vb,
|
||||
/* c_in */ DECODER_CIN,
|
||||
/* c_out */ DECODER_CIN,
|
||||
/* c_r */ 64,
|
||||
/* c_cond */ 1024,
|
||||
/* clip_embd */ 1024,
|
||||
/* patch_size */ 2,
|
||||
vb,
|
||||
)?
|
||||
};
|
||||
|
||||
@ -340,7 +346,7 @@ fn run(args: Args) -> Result<()> {
|
||||
let mut latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||
(b_size, DECODER_CIN, latent_height, latent_width),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
|
Reference in New Issue
Block a user