Only use classifier free guidance for the prior.
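encode_prompt now takes the negative prompt as an Option<&str>. The prior stage passes Some(&uncond_prompt) and keeps the classifier-free-guidance combine in its denoising loop; the decoder stage passes None and no longer concatenates an unconditional embedding. The prior denoising loop also moves out of the per-sample loop, so the effnet conditioning is computed once and reused, and the decoder loop gets its own DDPMWScheduler instead of reusing the prior's.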

Commit: b936e32e11
Parent: 9cf26c5cff
Author: laurent
Date: 2023-09-19 08:40:02 +01:00

@@ -156,7 +156,7 @@ fn output_filename(
 fn encode_prompt(
     prompt: &str,
-    uncond_prompt: &str,
+    uncond_prompt: Option<&str>,
     tokenizer: std::path::PathBuf,
     clip_weights: std::path::PathBuf,
     clip_config: stable_diffusion::clip::Config,
@@ -179,6 +179,13 @@ fn encode_prompt(
     }
     let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+    println!("Building the clip transformer.");
+    let text_model =
+        stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
+    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
+    match uncond_prompt {
+        None => Ok(text_embeddings),
+        Some(uncond_prompt) => {
     let mut uncond_tokens = tokenizer
         .encode(uncond_prompt, true)
         .map_err(E::msg)?
@@ -190,14 +197,13 @@ fn encode_prompt(
     }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
-    println!("Building the clip transformer.");
-    let text_model =
-        stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
-    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
-    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
+    let uncond_embeddings =
+        text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
     let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
     Ok(text_embeddings)
+        }
+    }
 }

 fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
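Aside: reduced to a minimal sketch, the new control flow in encode_prompt looks like the following. embed and Embedding here are hypothetical stand-ins for the CLIP text model and candle tensors, not part of the example.

// Minimal sketch of the new encode_prompt control flow: always embed the
// prompt, and only embed-and-concatenate the negative prompt when one is
// given. `embed` and `Embedding` are illustrative stand-ins.
#[derive(Debug, Clone)]
struct Embedding(Vec<f32>);

fn embed(prompt: &str) -> Embedding {
    // Stand-in for text_model.forward_with_mask(...).
    Embedding(prompt.bytes().map(|b| b as f32).collect())
}

fn encode_prompt(prompt: &str, uncond_prompt: Option<&str>) -> Vec<Embedding> {
    let text = embed(prompt);
    match uncond_prompt {
        // Decoder path: no classifier-free guidance, single embedding.
        None => vec![text],
        // Prior path: conditional + unconditional, like Tensor::cat(..., 0).
        Some(uncond) => vec![text, embed(uncond)],
    }
}

fn main() {
    println!("{}", encode_prompt("a cat", Some("blurry")).len()); // 2
    println!("{}", encode_prompt("a cat", None).len()); // 1
}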
@@ -239,7 +245,7 @@ fn run(args: Args) -> Result<()> {
         let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
         encode_prompt(
             &prompt,
-            &uncond_prompt,
+            Some(&uncond_prompt),
             tokenizer.clone(),
             weights,
             stable_diffusion::clip::Config::wuerstchen_prior(),
@@ -253,7 +259,7 @@ fn run(args: Args) -> Result<()> {
         let weights = ModelFile::Clip.get(clip_weights)?;
         encode_prompt(
             &prompt,
-            &uncond_prompt,
+            None,
             tokenizer.clone(),
             weights,
             stable_diffusion::clip::Config::wuerstchen(),
@@ -264,6 +270,17 @@ fn run(args: Args) -> Result<()> {
     println!("Building the prior.");
     // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
+    let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+    let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
+    let b_size = 1;
+    let effnet = {
+        let mut latents = Tensor::randn(
+            0f32,
+            1f32,
+            (b_size, PRIOR_CIN, latent_height, latent_width),
+            &device,
+        )?;
     let prior = {
         let prior_weights = ModelFile::Prior.get(prior_weights)?;
         let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
@@ -274,6 +291,27 @@ fn run(args: Args) -> Result<()> {
             /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
         )?
     };
+        let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
+        let timesteps = prior_scheduler.timesteps();
+        println!("prior denoising");
+        for (index, &t) in timesteps.iter().enumerate() {
+            let start_time = std::time::Instant::now();
+            if index == timesteps.len() - 1 {
+                continue;
+            }
+            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+            let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
+            let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
+            let noise_pred = noise_pred.chunk(2, 0)?;
+            let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
+            let noise_pred = (noise_pred_uncond
+                + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
+            latents = prior_scheduler.step(&noise_pred, t, &latents)?;
+            let dt = start_time.elapsed().as_secs_f32();
+            println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
+        }
+        ((latents * 42.)? - 1.)?
+    };

     println!("Building the vqgan.");
     let vqgan = {
@@ -298,9 +336,6 @@ fn run(args: Args) -> Result<()> {
         )?
     };
-    let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
-    let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
-    let b_size = 1;

     for idx in 0..num_samples {
         let mut latents = Tensor::randn(
             0f32,
@@ -309,34 +344,9 @@ fn run(args: Args) -> Result<()> {
             &device,
         )?;
-        let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
-        let timesteps = prior_scheduler.timesteps();
-        println!("prior denoising");
-        for (index, &t) in timesteps.iter().enumerate() {
-            let start_time = std::time::Instant::now();
-            if index == timesteps.len() - 1 {
-                continue;
-            }
-            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
-            let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
-            let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
-            let noise_pred = noise_pred.chunk(2, 0)?;
-            let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
-            let noise_pred = (noise_pred_uncond
-                + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
-            latents = prior_scheduler.step(&noise_pred, t, &latents)?;
-            let dt = start_time.elapsed().as_secs_f32();
-            println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
-        }
-        let effnet = ((latents * 42.)? - 1.)?;
-        let mut latents = Tensor::randn(
-            0f32,
-            1f32,
-            (b_size, PRIOR_CIN, latent_height, latent_width),
-            &device,
-        )?;

         println!("diffusion process");
+        let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
+        let timesteps = scheduler.timesteps();
         for (index, &t) in timesteps.iter().enumerate() {
             let start_time = std::time::Instant::now();
             if index == timesteps.len() - 1 {
@@ -344,7 +354,7 @@ fn run(args: Args) -> Result<()> {
             }
             let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
             let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
-            latents = prior_scheduler.step(&noise_pred, t, &latents)?;
+            latents = scheduler.step(&noise_pred, t, &latents)?;
             let dt = start_time.elapsed().as_secs_f32();
             println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
         }