diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs index b92fe8fd..2797f4b2 100644 --- a/candle-examples/examples/wuerstchen/main.rs +++ b/candle-examples/examples/wuerstchen/main.rs @@ -156,7 +156,7 @@ fn output_filename( fn encode_prompt( prompt: &str, - uncond_prompt: &str, + uncond_prompt: Option<&str>, tokenizer: std::path::PathBuf, clip_weights: std::path::PathBuf, clip_config: stable_diffusion::clip::Config, @@ -179,24 +179,30 @@ fn encode_prompt( } let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; - let mut uncond_tokens = tokenizer - .encode(uncond_prompt, true) - .map_err(E::msg)? - .get_ids() - .to_vec(); - let uncond_tokens_len = uncond_tokens.len(); - while uncond_tokens.len() < clip_config.max_position_embeddings { - uncond_tokens.push(pad_id) - } - let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; - println!("Building the clip transformer."); let text_model = stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?; let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?; - let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?; - let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?; - Ok(text_embeddings) + match uncond_prompt { + None => Ok(text_embeddings), + Some(uncond_prompt) => { + let mut uncond_tokens = tokenizer + .encode(uncond_prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let uncond_tokens_len = uncond_tokens.len(); + while uncond_tokens.len() < clip_config.max_position_embeddings { + uncond_tokens.push(pad_id) + } + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + + let uncond_embeddings = + text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?; + let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?; + Ok(text_embeddings) + } + } } fn run(args: Args) -> Result<()> { @@ -239,7 +245,7 @@ fn run(args: Args) -> Result<()> { let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?; encode_prompt( &prompt, - &uncond_prompt, + Some(&uncond_prompt), tokenizer.clone(), weights, stable_diffusion::clip::Config::wuerstchen_prior(), @@ -253,7 +259,7 @@ fn run(args: Args) -> Result<()> { let weights = ModelFile::Clip.get(clip_weights)?; encode_prompt( &prompt, - &uncond_prompt, + None, tokenizer.clone(), weights, stable_diffusion::clip::Config::wuerstchen(), @@ -264,15 +270,47 @@ fn run(args: Args) -> Result<()> { println!("Building the prior."); // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json - let prior = { - let prior_weights = ModelFile::Prior.get(prior_weights)?; - let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? }; - let weights = weights.deserialize()?; - let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); - wuerstchen::prior::WPrior::new( - /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280, - /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb, - )? + let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize; + let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize; + let b_size = 1; + let effnet = { + let mut latents = Tensor::randn( + 0f32, + 1f32, + (b_size, PRIOR_CIN, latent_height, latent_width), + &device, + )?; + + let prior = { + let prior_weights = ModelFile::Prior.get(prior_weights)?; + let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? }; + let weights = weights.deserialize()?; + let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + wuerstchen::prior::WPrior::new( + /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280, + /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb, + )? + }; + let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?; + let timesteps = prior_scheduler.timesteps(); + println!("prior denoising"); + for (index, &t) in timesteps.iter().enumerate() { + let start_time = std::time::Instant::now(); + if index == timesteps.len() - 1 { + continue; + } + let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; + let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?; + let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?; + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]); + let noise_pred = (noise_pred_uncond + + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?; + latents = prior_scheduler.step(&noise_pred, t, &latents)?; + let dt = start_time.elapsed().as_secs_f32(); + println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt); + } + ((latents * 42.)? - 1.)? }; println!("Building the vqgan."); @@ -298,9 +336,6 @@ fn run(args: Args) -> Result<()> { )? }; - let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize; - let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize; - let b_size = 1; for idx in 0..num_samples { let mut latents = Tensor::randn( 0f32, @@ -309,34 +344,9 @@ fn run(args: Args) -> Result<()> { &device, )?; - let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?; - let timesteps = prior_scheduler.timesteps(); - println!("prior denoising"); - for (index, &t) in timesteps.iter().enumerate() { - let start_time = std::time::Instant::now(); - if index == timesteps.len() - 1 { - continue; - } - let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; - let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?; - let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?; - let noise_pred = noise_pred.chunk(2, 0)?; - let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]); - let noise_pred = (noise_pred_uncond - + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?; - latents = prior_scheduler.step(&noise_pred, t, &latents)?; - let dt = start_time.elapsed().as_secs_f32(); - println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt); - } - let effnet = ((latents * 42.)? - 1.)?; - let mut latents = Tensor::randn( - 0f32, - 1f32, - (b_size, PRIOR_CIN, latent_height, latent_width), - &device, - )?; - println!("diffusion process"); + let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?; + let timesteps = scheduler.timesteps(); for (index, &t) in timesteps.iter().enumerate() { let start_time = std::time::Instant::now(); if index == timesteps.len() - 1 { @@ -344,7 +354,7 @@ fn run(args: Args) -> Result<()> { } let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?; let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?; - latents = prior_scheduler.step(&noise_pred, t, &latents)?; + latents = scheduler.step(&noise_pred, t, &latents)?; let dt = start_time.elapsed().as_secs_f32(); println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt); }