mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add the DDPM scheduler. (#877)
* Add the DDPM scheduler. * Minor tweaks.
This commit is contained in:
@ -15,6 +15,7 @@ use clap::Parser;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
const GUIDANCE_SCALE: f64 = 7.5;
|
||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -217,6 +218,8 @@ fn run(args: Args) -> Result<()> {
|
||||
};
|
||||
|
||||
let device = candle_examples::device(cpu)?;
|
||||
let height = height.unwrap_or(1024);
|
||||
let width = width.unwrap_or(1024);
|
||||
|
||||
let text_embeddings = encode_prompt(
|
||||
&prompt,
|
||||
@ -225,12 +228,12 @@ fn run(args: Args) -> Result<()> {
|
||||
clip_weights.clone(),
|
||||
stable_diffusion::clip::Config::wuerstchen(),
|
||||
&device,
|
||||
);
|
||||
)?;
|
||||
println!("{text_embeddings:?}");
|
||||
|
||||
println!("Building the prior.");
|
||||
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
||||
let _prior = {
|
||||
let prior = {
|
||||
let prior_weights = ModelFile::Prior.get(prior_weights)?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
@ -238,7 +241,7 @@ fn run(args: Args) -> Result<()> {
|
||||
wuerstchen::prior::WPrior::new(
|
||||
/* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64,
|
||||
/* depth */ 32, /* nhead */ 24, vb,
|
||||
)
|
||||
)?
|
||||
};
|
||||
|
||||
println!("Building the vqgan.");
|
||||
@ -264,8 +267,21 @@ fn run(args: Args) -> Result<()> {
|
||||
)?
|
||||
};
|
||||
|
||||
let _bsize = 1;
|
||||
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let b_size = 1;
|
||||
for idx in 0..num_samples {
|
||||
let latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(b_size, 4, latent_height, latent_width),
|
||||
&device,
|
||||
)?;
|
||||
// TODO: latents denoising loop, use the scheduler values.
|
||||
let ratio = Tensor::ones(1, DType::F32, &device)?;
|
||||
let prior = prior.forward(&latents, &ratio, &text_embeddings)?;
|
||||
|
||||
let latents = ((latents * 42.)? - 1.)?;
|
||||
/*
|
||||
let timesteps = scheduler.timesteps();
|
||||
let latents = Tensor::randn(
|
||||
|
Reference in New Issue
Block a user