mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Only use classifier free guidance for the prior.
This commit is contained in:
@ -156,7 +156,7 @@ fn output_filename(
|
|||||||
|
|
||||||
fn encode_prompt(
|
fn encode_prompt(
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
uncond_prompt: &str,
|
uncond_prompt: Option<&str>,
|
||||||
tokenizer: std::path::PathBuf,
|
tokenizer: std::path::PathBuf,
|
||||||
clip_weights: std::path::PathBuf,
|
clip_weights: std::path::PathBuf,
|
||||||
clip_config: stable_diffusion::clip::Config,
|
clip_config: stable_diffusion::clip::Config,
|
||||||
@ -179,6 +179,13 @@ fn encode_prompt(
|
|||||||
}
|
}
|
||||||
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
|
println!("Building the clip transformer.");
|
||||||
|
let text_model =
|
||||||
|
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
||||||
|
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
|
||||||
|
match uncond_prompt {
|
||||||
|
None => Ok(text_embeddings),
|
||||||
|
Some(uncond_prompt) => {
|
||||||
let mut uncond_tokens = tokenizer
|
let mut uncond_tokens = tokenizer
|
||||||
.encode(uncond_prompt, true)
|
.encode(uncond_prompt, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
@ -190,13 +197,12 @@ fn encode_prompt(
|
|||||||
}
|
}
|
||||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
println!("Building the clip transformer.");
|
let uncond_embeddings =
|
||||||
let text_model =
|
text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
|
||||||
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
|
||||||
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
|
|
||||||
let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
|
|
||||||
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
|
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
|
||||||
Ok(text_embeddings)
|
Ok(text_embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run(args: Args) -> Result<()> {
|
fn run(args: Args) -> Result<()> {
|
||||||
@ -239,7 +245,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
|
let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
|
||||||
encode_prompt(
|
encode_prompt(
|
||||||
&prompt,
|
&prompt,
|
||||||
&uncond_prompt,
|
Some(&uncond_prompt),
|
||||||
tokenizer.clone(),
|
tokenizer.clone(),
|
||||||
weights,
|
weights,
|
||||||
stable_diffusion::clip::Config::wuerstchen_prior(),
|
stable_diffusion::clip::Config::wuerstchen_prior(),
|
||||||
@ -253,7 +259,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let weights = ModelFile::Clip.get(clip_weights)?;
|
let weights = ModelFile::Clip.get(clip_weights)?;
|
||||||
encode_prompt(
|
encode_prompt(
|
||||||
&prompt,
|
&prompt,
|
||||||
&uncond_prompt,
|
None,
|
||||||
tokenizer.clone(),
|
tokenizer.clone(),
|
||||||
weights,
|
weights,
|
||||||
stable_diffusion::clip::Config::wuerstchen(),
|
stable_diffusion::clip::Config::wuerstchen(),
|
||||||
@ -264,6 +270,17 @@ fn run(args: Args) -> Result<()> {
|
|||||||
|
|
||||||
println!("Building the prior.");
|
println!("Building the prior.");
|
||||||
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
||||||
|
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||||
|
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||||
|
let b_size = 1;
|
||||||
|
let effnet = {
|
||||||
|
let mut latents = Tensor::randn(
|
||||||
|
0f32,
|
||||||
|
1f32,
|
||||||
|
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
|
||||||
let prior = {
|
let prior = {
|
||||||
let prior_weights = ModelFile::Prior.get(prior_weights)?;
|
let prior_weights = ModelFile::Prior.get(prior_weights)?;
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
|
||||||
@ -274,6 +291,27 @@ fn run(args: Args) -> Result<()> {
|
|||||||
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
|
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||||
|
let timesteps = prior_scheduler.timesteps();
|
||||||
|
println!("prior denoising");
|
||||||
|
for (index, &t) in timesteps.iter().enumerate() {
|
||||||
|
let start_time = std::time::Instant::now();
|
||||||
|
if index == timesteps.len() - 1 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||||
|
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||||
|
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||||
|
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||||
|
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
|
||||||
|
let noise_pred = (noise_pred_uncond
|
||||||
|
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
|
||||||
|
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||||
|
let dt = start_time.elapsed().as_secs_f32();
|
||||||
|
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||||
|
}
|
||||||
|
((latents * 42.)? - 1.)?
|
||||||
|
};
|
||||||
|
|
||||||
println!("Building the vqgan.");
|
println!("Building the vqgan.");
|
||||||
let vqgan = {
|
let vqgan = {
|
||||||
@ -298,9 +336,6 @@ fn run(args: Args) -> Result<()> {
|
|||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
|
||||||
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
|
||||||
let b_size = 1;
|
|
||||||
for idx in 0..num_samples {
|
for idx in 0..num_samples {
|
||||||
let mut latents = Tensor::randn(
|
let mut latents = Tensor::randn(
|
||||||
0f32,
|
0f32,
|
||||||
@ -309,34 +344,9 @@ fn run(args: Args) -> Result<()> {
|
|||||||
&device,
|
&device,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
|
||||||
let timesteps = prior_scheduler.timesteps();
|
|
||||||
println!("prior denoising");
|
|
||||||
for (index, &t) in timesteps.iter().enumerate() {
|
|
||||||
let start_time = std::time::Instant::now();
|
|
||||||
if index == timesteps.len() - 1 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
|
||||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
|
||||||
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
|
||||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
|
||||||
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
|
|
||||||
let noise_pred = (noise_pred_uncond
|
|
||||||
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
|
|
||||||
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
|
||||||
let dt = start_time.elapsed().as_secs_f32();
|
|
||||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
|
||||||
}
|
|
||||||
let effnet = ((latents * 42.)? - 1.)?;
|
|
||||||
let mut latents = Tensor::randn(
|
|
||||||
0f32,
|
|
||||||
1f32,
|
|
||||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
|
||||||
&device,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
println!("diffusion process");
|
println!("diffusion process");
|
||||||
|
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||||
|
let timesteps = scheduler.timesteps();
|
||||||
for (index, &t) in timesteps.iter().enumerate() {
|
for (index, &t) in timesteps.iter().enumerate() {
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
if index == timesteps.len() - 1 {
|
if index == timesteps.len() - 1 {
|
||||||
@ -344,7 +354,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
}
|
}
|
||||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||||
let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
|
let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
|
||||||
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
latents = scheduler.step(&noise_pred, t, &latents)?;
|
||||||
let dt = start_time.elapsed().as_secs_f32();
|
let dt = start_time.elapsed().as_secs_f32();
|
||||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user