diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs index d40cfbd9..b92fe8fd 100644 --- a/candle-examples/examples/wuerstchen/main.rs +++ b/candle-examples/examples/wuerstchen/main.rs @@ -248,6 +248,20 @@ fn run(args: Args) -> Result<()> { }; println!("{prior_text_embeddings}"); + let text_embeddings = { + let tokenizer = ModelFile::Tokenizer.get(tokenizer)?; + let weights = ModelFile::Clip.get(clip_weights)?; + encode_prompt( + &prompt, + &uncond_prompt, + tokenizer.clone(), + weights, + stable_diffusion::clip::Config::wuerstchen(), + &device, + )? + }; + println!("{prior_text_embeddings}"); + println!("Building the prior."); // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json let prior = { @@ -262,7 +276,7 @@ fn run(args: Args) -> Result<()> { }; println!("Building the vqgan."); - let _vqgan = { + let vqgan = { let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?; let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? }; let weights = weights.deserialize()?; @@ -273,7 +287,7 @@ fn run(args: Args) -> Result<()> { println!("Building the decoder."); // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json - let _decoder = { + let decoder = { let decoder_weights = ModelFile::Decoder.get(decoder_weights)?; let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? }; let weights = weights.deserialize()?; @@ -314,49 +328,37 @@ fn run(args: Args) -> Result<()> { let dt = start_time.elapsed().as_secs_f32(); println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt); } - let latents = ((latents * 42.)? - 1.)?; - /* - let timesteps = scheduler.timesteps(); - let latents = Tensor::randn( + let effnet = ((latents * 42.)? 
- 1.)?; + let mut latents = Tensor::randn( 0f32, 1f32, - (bsize, 4, sd_config.height / 8, sd_config.width / 8), + (b_size, PRIOR_CIN, latent_height, latent_width), &device, )?; - // scale the initial noise by the standard deviation required by the scheduler - let mut latents = latents * scheduler.init_noise_sigma()?; - println!("starting sampling"); - for (timestep_index, &timestep) in timesteps.iter().enumerate() { + println!("diffusion process"); + for (index, &t) in timesteps.iter().enumerate() { let start_time = std::time::Instant::now(); - let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; - - let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; - let noise_pred = - decoder.forward(&latent_model_input, timestep as f64, &text_embeddings)?; - let noise_pred = noise_pred.chunk(2, 0)?; - let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); - let noise_pred = - (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?; - latents = scheduler.step(&noise_pred, timestep, &latents)?; + if index == timesteps.len() - 1 { + continue; + } + let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?; + let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?; + latents = prior_scheduler.step(&noise_pred, t, &latents)?; let dt = start_time.elapsed().as_secs_f32(); - println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt); + println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt); } - */ - println!( "Generating the final image for sample {}/{}.", idx + 1, num_samples ); - /* - let image = vae.decode(&(&latents / 0.18215)?)?; + let image = vqgan.decode(&(&latents * 0.3764)?)?; // TODO: Add the clamping between 0 and 1. let image = ((image / 2.)? 
+ 0.5)?.to_device(&Device::Cpu)?; let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; let image_filename = output_filename(&final_image, idx + 1, num_samples, None); candle_examples::save_image(&image, image_filename)? - */ } Ok(()) }