mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
Fix the latent shape.
This commit is contained in:
@ -17,6 +17,7 @@ use tokenizers::Tokenizer;
|
|||||||
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
||||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||||
const PRIOR_CIN: usize = 16;
|
const PRIOR_CIN: usize = 16;
|
||||||
|
const DECODER_CIN: usize = 4;
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
@ -331,8 +332,13 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||||
wuerstchen::diffnext::WDiffNeXt::new(
|
wuerstchen::diffnext::WDiffNeXt::new(
|
||||||
/* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
|
/* c_in */ DECODER_CIN,
|
||||||
/* clip_embd */ 1024, /* patch_size */ 2, vb,
|
/* c_out */ DECODER_CIN,
|
||||||
|
/* c_r */ 64,
|
||||||
|
/* c_cond */ 1024,
|
||||||
|
/* clip_embd */ 1024,
|
||||||
|
/* patch_size */ 2,
|
||||||
|
vb,
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -340,7 +346,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let mut latents = Tensor::randn(
|
let mut latents = Tensor::randn(
|
||||||
0f32,
|
0f32,
|
||||||
1f32,
|
1f32,
|
||||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
(b_size, DECODER_CIN, latent_height, latent_width),
|
||||||
&device,
|
&device,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user