W decoding. (#893)

* W decoding.

* Add the diffusion loop.

* Use the appropriate config.
This commit is contained in:
Laurent Mazare
2023-09-19 07:13:44 +01:00
committed by GitHub
parent 92db8cecd3
commit aaa9d4ed6c

View File

@ -248,6 +248,20 @@ fn run(args: Args) -> Result<()> {
}; };
println!("{prior_text_embeddings}"); println!("{prior_text_embeddings}");
let text_embeddings = {
let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
let weights = ModelFile::Clip.get(clip_weights)?;
encode_prompt(
&prompt,
&uncond_prompt,
tokenizer.clone(),
weights,
stable_diffusion::clip::Config::wuerstchen(),
&device,
)?
};
println!("{prior_text_embeddings}");
println!("Building the prior."); println!("Building the prior.");
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
let prior = { let prior = {
@ -262,7 +276,7 @@ fn run(args: Args) -> Result<()> {
}; };
println!("Building the vqgan."); println!("Building the vqgan.");
let _vqgan = { let vqgan = {
let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?; let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
let weights = weights.deserialize()?; let weights = weights.deserialize()?;
@ -273,7 +287,7 @@ fn run(args: Args) -> Result<()> {
println!("Building the decoder."); println!("Building the decoder.");
// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
let _decoder = { let decoder = {
let decoder_weights = ModelFile::Decoder.get(decoder_weights)?; let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
let weights = weights.deserialize()?; let weights = weights.deserialize()?;
@ -314,49 +328,37 @@ fn run(args: Args) -> Result<()> {
let dt = start_time.elapsed().as_secs_f32(); let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt); println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
} }
let latents = ((latents * 42.)? - 1.)?; let effnet = ((latents * 42.)? - 1.)?;
/* let mut latents = Tensor::randn(
let timesteps = scheduler.timesteps();
let latents = Tensor::randn(
0f32, 0f32,
1f32, 1f32,
(bsize, 4, sd_config.height / 8, sd_config.width / 8), (b_size, PRIOR_CIN, latent_height, latent_width),
&device, &device,
)?; )?;
// scale the initial noise by the standard deviation required by the scheduler
let mut latents = latents * scheduler.init_noise_sigma()?;
println!("starting sampling"); println!("diffusion process");
for (timestep_index, &timestep) in timesteps.iter().enumerate() { for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; if index == timesteps.len() - 1 {
continue;
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; }
let noise_pred = let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
decoder.forward(&latent_model_input, timestep as f64, &text_embeddings)?; let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
let noise_pred = noise_pred.chunk(2, 0)?; latents = prior_scheduler.step(&noise_pred, t, &latents)?;
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); let dt = start_time.elapsed().as_secs_f32();
let noise_pred = println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
latents = scheduler.step(&noise_pred, timestep, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
} }
*/
println!( println!(
"Generating the final image for sample {}/{}.", "Generating the final image for sample {}/{}.",
idx + 1, idx + 1,
num_samples num_samples
); );
/* let image = vqgan.decode(&(&latents * 0.3764)?)?;
let image = vae.decode(&(&latents / 0.18215)?)?;
// TODO: Add the clamping between 0 and 1. // TODO: Add the clamping between 0 and 1.
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None); let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)? candle_examples::save_image(&image, image_filename)?
*/
} }
Ok(()) Ok(())
} }