From 7299a683534a84e13ec3ff8a92b2fff77102d7e7 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 6 Sep 2023 08:06:49 +0200
Subject: [PATCH] img2img pipeline for stable diffusion. (#752)

* img2img pipeline for stable diffusion.

* Rename the arguments + fix.

* Fix for zero strength.

* Another fix.

* Another fix.

* Revert.

* Include the backtrace.

* Noise scaling.

* Fix the height/width.
---
 candle-core/src/error.rs                      |  4 +-
 .../examples/stable-diffusion/ddim.rs         | 11 +++
 .../examples/stable-diffusion/main.rs         | 85 ++++++++++++++++---
 candle-examples/src/lib.rs                    |  8 +-
 4 files changed, 91 insertions(+), 17 deletions(-)

diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 1cf20a84..d030fab1 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -207,11 +207,11 @@ pub type Result<T> = std::result::Result<T, Error>;
 
 impl Error {
     pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
-        Self::Wrapped(Box::new(err))
+        Self::Wrapped(Box::new(err)).bt()
     }
 
     pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
-        Self::Msg(err.to_string())
+        Self::Msg(err.to_string()).bt()
     }
 
     pub fn bt(self) -> Self {
diff --git a/candle-examples/examples/stable-diffusion/ddim.rs b/candle-examples/examples/stable-diffusion/ddim.rs
index f2e021ce..260a4965 100644
--- a/candle-examples/examples/stable-diffusion/ddim.rs
+++ b/candle-examples/examples/stable-diffusion/ddim.rs
@@ -163,6 +163,17 @@ impl DDIMScheduler {
         }
     }
 
+    pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+        let timestep = if timestep >= self.alphas_cumprod.len() {
+            timestep - 1
+        } else {
+            timestep
+        };
+        let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
+        let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
+        (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
+    }
+
     pub fn init_noise_sigma(&self) -> f64 {
         self.init_noise_sigma
     }
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 8372edcd..70e1e92c 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -96,6 +96,15 @@ struct Args {
 
     #[arg(long)]
     use_f16: bool,
+
+    #[arg(long, value_name = "FILE")]
+    img2img: Option<String>,
+
+    /// The strength indicates how much to transform the initial image. The
+    /// value must be between 0 and 1; a value of 1 discards the initial image
+    /// information.
+    #[arg(long, default_value_t = 0.8)]
+    img2img_strength: f64,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -306,6 +315,26 @@ fn text_embeddings(
     Ok(text_embeddings)
 }
 
+fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
+    let img = image::io::Reader::open(path)?.decode()?;
+    let (height, width) = (img.height() as usize, img.width() as usize);
+    let height = height - height % 32;
+    let width = width - width % 32;
+    let img = img.resize_to_fill(
+        width as u32,
+        height as u32,
+        image::imageops::FilterType::CatmullRom,
+    );
+    let img = img.to_rgb8();
+    let img = img.into_raw();
+    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+        .permute((2, 0, 1))?
+        .to_dtype(DType::F32)?
+        .affine(2. / 255., -1.)?
+        .unsqueeze(0)?;
+    Ok(img)
+}
+
 fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -328,9 +357,15 @@ fn run(args: Args) -> Result<()> {
         tracing,
         use_f16,
         use_flash_attn,
+        img2img,
+        img2img_strength,
         ..
     } = args;
 
+    if !(0. ..=1.).contains(&img2img_strength) {
+        anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
+    }
+
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -382,25 +417,53 @@ fn run(args: Args) -> Result<()> {
     println!("Building the autoencoder.");
     let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
     let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
+    let init_latent_dist = match &img2img {
+        None => None,
+        Some(image) => {
+            let image = image_preprocess(image)?.to_device(&device)?;
+            Some(vae.encode(&image)?)
+        }
+    };
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
     let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
 
+    let t_start = if img2img.is_some() {
+        n_steps - (n_steps as f64 * img2img_strength) as usize
+    } else {
+        0
+    };
     let bsize = 1;
     for idx in 0..num_samples {
-        let mut latents = Tensor::randn(
-            0f32,
-            1f32,
-            (bsize, 4, sd_config.height / 8, sd_config.width / 8),
-            &device,
-        )?
-        .to_dtype(dtype)?;
-
-        // scale the initial noise by the standard deviation required by the scheduler
-        latents = (latents * scheduler.init_noise_sigma())?;
+        let timesteps = scheduler.timesteps();
+        let latents = match &init_latent_dist {
+            Some(init_latent_dist) => {
+                let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+                if t_start < timesteps.len() {
+                    let noise = latents.randn_like(0f64, 1f64)?;
+                    scheduler.add_noise(&latents, noise, timesteps[t_start])?
+                } else {
+                    latents
+                }
+            }
+            None => {
+                let latents = Tensor::randn(
+                    0f32,
+                    1f32,
+                    (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+                    &device,
+                )?;
+                // scale the initial noise by the standard deviation required by the scheduler
+                (latents * scheduler.init_noise_sigma())?
+            }
+        };
+        let mut latents = latents.to_dtype(dtype)?;
 
         println!("starting sampling");
-        for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
+        for (timestep_index, &timestep) in timesteps.iter().enumerate() {
+            if timestep_index < t_start {
+                continue;
+            }
             let start_time = std::time::Instant::now();
             let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
 
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 395162eb..f9581b02 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -35,14 +35,14 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
 }
 
 /// Saves an image to disk using the image crate, this expects an input with shape
-/// (c, width, height).
+/// (c, height, width).
 pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
     let p = p.as_ref();
-    let (channel, width, height) = img.dims3()?;
+    let (channel, height, width) = img.dims3()?;
     if channel != 3 {
-        candle::bail!("save_image expects an input of shape (3, width, height)")
+        candle::bail!("save_image expects an input of shape (3, height, width)")
     }
-    let img = img.transpose(0, 1)?.t()?.flatten_all()?;
+    let img = img.permute((1, 2, 0))?.flatten_all()?;
     let pixels = img.to_vec1::<u8>()?;
     let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
         match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
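
Note: the img2img path added above comes down to two small pieces of arithmetic. The
img2img_strength argument only decides how many of the n_steps denoising steps are
skipped (t_start), and DDIMScheduler::add_noise applies the usual DDIM forward-noising
rule x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise to the encoded image.
The standalone Rust sketch below walks through both on scalar values; everything in it
(n_steps, strength, alpha_bar_t, x0, noise) is an illustrative assumption and not code
taken from the patch.

// Standalone illustration of the arithmetic used by the img2img path; the
// names (n_steps, strength, alpha_bar_t, x0, noise) are hypothetical.
fn main() {
    // Strength -> starting step: the default strength of 0.8 over 30 steps
    // skips the first 6 denoising steps, so 24 steps transform the image.
    let (n_steps, strength) = (30usize, 0.8f64);
    let t_start = n_steps - (n_steps as f64 * strength) as usize;
    assert_eq!(t_start, 6);

    // DDIM forward noising, the formula behind DDIMScheduler::add_noise:
    // x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    // shown element-wise on scalars instead of tensors.
    let alpha_bar_t = 0.5f64; // hypothetical cumulative alpha product at t_start
    let (x0, noise) = (1.0f64, 0.25f64); // hypothetical latent value and noise sample
    let x_t = alpha_bar_t.sqrt() * x0 + (1.0 - alpha_bar_t).sqrt() * noise;
    println!("t_start = {t_start}, noised latent value = {x_t:.4}");
}

With the patch applied, the new path would be exercised with something along the lines of
"cargo run --example stable-diffusion --release -- --prompt ... --img2img input.png
--img2img-strength 0.8"; the kebab-case flags --img2img and --img2img-strength are the
ones clap derives from the new Args fields.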