mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Prior denoising. (#889)
This commit is contained in:
@ -14,7 +14,7 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
const GUIDANCE_SCALE: f64 = 7.5;
|
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
||||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||||
const PRIOR_CIN: usize = 16;
|
const PRIOR_CIN: usize = 16;
|
||||||
|
|
||||||
@ -288,16 +288,32 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||||
let b_size = 1;
|
let b_size = 1;
|
||||||
for idx in 0..num_samples {
|
for idx in 0..num_samples {
|
||||||
let latents = Tensor::randn(
|
let mut latents = Tensor::randn(
|
||||||
0f32,
|
0f32,
|
||||||
1f32,
|
1f32,
|
||||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||||
&device,
|
&device,
|
||||||
)?;
|
)?;
|
||||||
// TODO: latents denoising loop, use the scheduler values.
|
|
||||||
let ratio = Tensor::ones(1, DType::F32, &device)?;
|
|
||||||
let prior = prior.forward(&latents, &ratio, &prior_text_embeddings)?;
|
|
||||||
|
|
||||||
|
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||||
|
let timesteps = prior_scheduler.timesteps();
|
||||||
|
println!("prior denoising");
|
||||||
|
for (index, &t) in timesteps.iter().enumerate() {
|
||||||
|
let start_time = std::time::Instant::now();
|
||||||
|
if index == timesteps.len() - 1 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||||
|
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||||
|
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||||
|
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||||
|
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
|
||||||
|
let noise_pred = (noise_pred_uncond
|
||||||
|
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
|
||||||
|
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||||
|
let dt = start_time.elapsed().as_secs_f32();
|
||||||
|
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||||
|
}
|
||||||
let latents = ((latents * 42.)? - 1.)?;
|
let latents = ((latents * 42.)? - 1.)?;
|
||||||
/*
|
/*
|
||||||
let timesteps = scheduler.timesteps();
|
let timesteps = scheduler.timesteps();
|
||||||
|
@ -38,6 +38,10 @@ impl DDPMWScheduler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the timestep values held in this scheduler's `timesteps` field,
/// borrowed as a slice of `f64`.
pub fn timesteps(&self) -> &[f64] {
|
||||||
|
&self.timesteps
|
||||||
|
}
|
||||||
|
|
||||||
fn alpha_cumprod(&self, t: f64) -> f64 {
|
fn alpha_cumprod(&self, t: f64) -> f64 {
|
||||||
let scaler = self.config.scaler;
|
let scaler = self.config.scaler;
|
||||||
let s = self.config.s;
|
let s = self.config.s;
|
||||||
|
Reference in New Issue
Block a user