mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Only use classifier free guidance for the prior. (#896)
* Only use classifier free guidance for the prior. * Add another specific layer-norm structure. * Tweaks. * Fix the latent shape. * Print the prior shape. * More shape fixes. * Remove some debugging continue.
This commit is contained in:
@ -16,7 +16,9 @@ use tokenizers::Tokenizer;
|
||||
|
||||
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||
const LATENT_DIM_SCALE: f64 = 10.67;
|
||||
const PRIOR_CIN: usize = 16;
|
||||
const DECODER_CIN: usize = 4;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -156,7 +158,7 @@ fn output_filename(
|
||||
|
||||
fn encode_prompt(
|
||||
prompt: &str,
|
||||
uncond_prompt: &str,
|
||||
uncond_prompt: Option<&str>,
|
||||
tokenizer: std::path::PathBuf,
|
||||
clip_weights: std::path::PathBuf,
|
||||
clip_config: stable_diffusion::clip::Config,
|
||||
@ -179,24 +181,30 @@ fn encode_prompt(
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let uncond_tokens_len = uncond_tokens.len();
|
||||
while uncond_tokens.len() < clip_config.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the clip transformer.");
|
||||
let text_model =
|
||||
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
||||
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
|
||||
let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
|
||||
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
|
||||
Ok(text_embeddings)
|
||||
match uncond_prompt {
|
||||
None => Ok(text_embeddings),
|
||||
Some(uncond_prompt) => {
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let uncond_tokens_len = uncond_tokens.len();
|
||||
while uncond_tokens.len() < clip_config.max_position_embeddings {
|
||||
uncond_tokens.push(pad_id)
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
let uncond_embeddings =
|
||||
text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
|
||||
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
|
||||
Ok(text_embeddings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<()> {
|
||||
@ -239,40 +247,72 @@ fn run(args: Args) -> Result<()> {
|
||||
let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
|
||||
encode_prompt(
|
||||
&prompt,
|
||||
&uncond_prompt,
|
||||
Some(&uncond_prompt),
|
||||
tokenizer.clone(),
|
||||
weights,
|
||||
stable_diffusion::clip::Config::wuerstchen_prior(),
|
||||
&device,
|
||||
)?
|
||||
};
|
||||
println!("{prior_text_embeddings}");
|
||||
println!("generated prior text embeddings {prior_text_embeddings:?}");
|
||||
|
||||
let text_embeddings = {
|
||||
let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
|
||||
let weights = ModelFile::Clip.get(clip_weights)?;
|
||||
encode_prompt(
|
||||
&prompt,
|
||||
&uncond_prompt,
|
||||
None,
|
||||
tokenizer.clone(),
|
||||
weights,
|
||||
stable_diffusion::clip::Config::wuerstchen(),
|
||||
&device,
|
||||
)?
|
||||
};
|
||||
println!("{prior_text_embeddings}");
|
||||
println!("generated text embeddings {text_embeddings:?}");
|
||||
|
||||
println!("Building the prior.");
|
||||
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
||||
let prior = {
|
||||
let prior_weights = ModelFile::Prior.get(prior_weights)?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
wuerstchen::prior::WPrior::new(
|
||||
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
|
||||
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
|
||||
)?
|
||||
let b_size = 1;
|
||||
let image_embeddings = {
|
||||
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
||||
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let mut latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
let prior = {
|
||||
let prior_weights = ModelFile::Prior.get(prior_weights)?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
wuerstchen::prior::WPrior::new(
|
||||
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
|
||||
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
|
||||
)?
|
||||
};
|
||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = prior_scheduler.timesteps();
|
||||
println!("prior denoising");
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
|
||||
let noise_pred = (noise_pred_uncond
|
||||
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
|
||||
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||
}
|
||||
((latents * 42.)? - 1.)?
|
||||
};
|
||||
|
||||
println!("Building the vqgan.");
|
||||
@ -293,58 +333,40 @@ fn run(args: Args) -> Result<()> {
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
wuerstchen::diffnext::WDiffNeXt::new(
|
||||
/* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
|
||||
/* clip_embd */ 1024, /* patch_size */ 2, vb,
|
||||
/* c_in */ DECODER_CIN,
|
||||
/* c_out */ DECODER_CIN,
|
||||
/* c_r */ 64,
|
||||
/* c_cond */ 1024,
|
||||
/* clip_embd */ 1024,
|
||||
/* patch_size */ 2,
|
||||
vb,
|
||||
)?
|
||||
};
|
||||
|
||||
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let b_size = 1;
|
||||
for idx in 0..num_samples {
|
||||
// https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
|
||||
let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
|
||||
let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
|
||||
|
||||
let mut latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||
(b_size, DECODER_CIN, latent_height, latent_width),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = prior_scheduler.timesteps();
|
||||
println!("prior denoising");
|
||||
println!("diffusion process with prior {image_embeddings:?}");
|
||||
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = scheduler.timesteps();
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
|
||||
let noise_pred = (noise_pred_uncond
|
||||
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
|
||||
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||
}
|
||||
let effnet = ((latents * 42.)? - 1.)?;
|
||||
let mut latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
println!("diffusion process");
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||
let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
|
||||
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
|
||||
let noise_pred =
|
||||
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
|
||||
latents = scheduler.step(&noise_pred, t, &latents)?;
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||
}
|
||||
|
Reference in New Issue
Block a user