diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index ebf0bfcb..2bfb6422 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -5,10 +5,12 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use candle_transformers::models::stable_diffusion;
+use std::ops::Div;
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
+use rand::Rng;
 use stable_diffusion::vae::AutoEncoderKL;
 use tokenizers::Tokenizer;
 
@@ -49,6 +51,10 @@ struct Args {
     #[arg(long, value_name = "FILE")]
     clip_weights: Option<String>,
 
+    /// The CLIP2 weight file, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    clip2_weights: Option<String>,
+
     /// The VAE weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     vae_weights: Option<String>,
@@ -93,6 +99,11 @@ struct Args {
     #[arg(long)]
     guidance_scale: Option<f64>,
 
+    /// Path to the mask image for inpainting.
+    #[arg(long, value_name = "FILE")]
+    mask_path: Option<String>,
+
+    /// Path to the image used to initialize the latents. For inpainting, this is the image to be masked.
     #[arg(long, value_name = "FILE")]
     img2img: Option<String>,
 
@@ -105,13 +116,20 @@ struct Args {
     /// The seed to use when generating random samples.
     #[arg(long)]
     seed: Option<u64>,
+
+    /// Force the saved image to update only the masked region
+    #[arg(long)]
+    only_update_masked: bool,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
 enum StableDiffusionVersion {
     V1_5,
+    V1_5Inpaint,
     V2_1,
+    V2Inpaint,
     Xl,
+    XlInpaint,
     Turbo,
 }
 
@@ -128,16 +146,25 @@ enum ModelFile {
 impl StableDiffusionVersion {
     fn repo(&self) -> &'static str {
         match self {
+            Self::XlInpaint => "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
             Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
+            Self::V2Inpaint => "stabilityai/stable-diffusion-2-inpainting",
             Self::V2_1 => "stabilityai/stable-diffusion-2-1",
             Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+            Self::V1_5Inpaint => "stable-diffusion-v1-5/stable-diffusion-inpainting",
             Self::Turbo => "stabilityai/sdxl-turbo",
         }
     }
 
     fn unet_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
+            Self::V1_5
+            | Self::V1_5Inpaint
+            | Self::V2_1
+            | Self::V2Inpaint
+            | Self::Xl
+            | Self::XlInpaint
+            | Self::Turbo => {
                 if use_f16 {
                     "unet/diffusion_pytorch_model.fp16.safetensors"
                 } else {
@@ -149,7 +176,13 @@ impl StableDiffusionVersion {
 
     fn vae_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
+            Self::V1_5
+            | Self::V1_5Inpaint
+            | Self::V2_1
+            | Self::V2Inpaint
+            | Self::Xl
+            | Self::XlInpaint
+            | Self::Turbo => {
                 if use_f16 {
                     "vae/diffusion_pytorch_model.fp16.safetensors"
                 } else {
@@ -161,7 +194,13 @@ impl StableDiffusionVersion {
 
     fn clip_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
+            Self::V1_5
+            | Self::V1_5Inpaint
+            | Self::V2_1
+            | Self::V2Inpaint
+            | Self::Xl
+            | Self::XlInpaint
+            | Self::Turbo => {
                 if use_f16 {
                     "text_encoder/model.fp16.safetensors"
                 } else {
@@ -173,7 +212,13 @@ impl StableDiffusionVersion {
 
     fn clip2_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
+            Self::V1_5
+            | Self::V1_5Inpaint
+            | Self::V2_1
+            | Self::V2Inpaint
+            | Self::Xl
+            | Self::XlInpaint
+            | Self::Turbo => {
                 if use_f16 {
                     "text_encoder_2/model.fp16.safetensors"
                 } else {
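Aside: the new enum variants become command-line values through clap's `ValueEnum` derive, which kebab-cases variant names (the existing `V1_5` already parses as `v1-5`). A minimal sketch of that naming, assuming clap 4 with the `derive` feature, as the example already uses:

    use clap::ValueEnum;

    #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
    enum StableDiffusionVersion {
        V1_5,
        V1_5Inpaint,
        V2_1,
        V2Inpaint,
        Xl,
        XlInpaint,
        Turbo,
    }

    fn main() {
        // clap derives the CLI value from the variant name: V1_5Inpaint -> "v1-5-inpaint".
        let v = StableDiffusionVersion::from_str("xl-inpaint", true).unwrap();
        assert_eq!(v, StableDiffusionVersion::XlInpaint);
        let v = StableDiffusionVersion::from_str("v1-5-inpaint", true).unwrap();
        assert_eq!(v, StableDiffusionVersion::V1_5Inpaint);
    }

So an SDXL inpainting run selects, e.g., `--sd-version xl-inpaint` together with `--img2img` and `--mask-path`.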
@@ -198,10 +243,13 @@ impl ModelFile {
         let (repo, path) = match self {
             Self::Tokenizer => {
                 let tokenizer_repo = match version {
-                    StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
-                        "openai/clip-vit-base-patch32"
-                    }
-                    StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
+                    StableDiffusionVersion::V1_5
+                    | StableDiffusionVersion::V2_1
+                    | StableDiffusionVersion::V1_5Inpaint
+                    | StableDiffusionVersion::V2Inpaint => "openai/clip-vit-base-patch32",
+                    StableDiffusionVersion::Xl
+                    | StableDiffusionVersion::XlInpaint
+                    | StableDiffusionVersion::Turbo => {
                         // This seems similar to the patch32 version except some very small
                         // difference in the split regex.
                         "openai/clip-vit-large-patch14"
@@ -299,6 +347,7 @@ fn text_embeddings(
     uncond_prompt: &str,
     tokenizer: Option<String>,
     clip_weights: Option<String>,
+    clip2_weights: Option<String>,
     sd_version: StableDiffusionVersion,
     sd_config: &stable_diffusion::StableDiffusionConfig,
     use_f16: bool,
@@ -342,7 +391,11 @@ fn text_embeddings(
     } else {
         ModelFile::Clip2
     };
-    let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
+    let clip_weights = if first {
+        clip_weights_file.get(clip_weights, sd_version, use_f16)?
+    } else {
+        clip_weights_file.get(clip2_weights, sd_version, use_f16)?
+    };
     let clip_config = if first {
         &sd_config.clip
     } else {
@@ -399,6 +452,82 @@ fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
     Ok(img)
 }
 
+fn mask_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
+    let img = image::open(path)?.to_luma8();
+    let (new_width, new_height) = {
+        let (width, height) = img.dimensions();
+        (width - width % 32, height - height % 32)
+    };
+    let img = image::imageops::resize(
+        &img,
+        new_width,
+        new_height,
+        image::imageops::FilterType::CatmullRom,
+    )
+    .into_raw();
+    let mask = Tensor::from_vec(img, (new_height as usize, new_width as usize), &Device::Cpu)?
+        .unsqueeze(0)?
+        .to_dtype(DType::F32)?
+        .div(255.0)?
+        .unsqueeze(0)?;
+    Ok(mask)
+}
+
+/// Generates the mask latents, the scaled mask, and mask_4 for inpainting. Returns a tuple of
+/// `None`s if inpainting is not being used.
+#[allow(clippy::too_many_arguments)]
+fn inpainting_tensors(
+    sd_version: StableDiffusionVersion,
+    mask_path: Option<String>,
+    dtype: DType,
+    device: &Device,
+    use_guide_scale: bool,
+    vae: &AutoEncoderKL,
+    image: Option<Tensor>,
+    vae_scale: f64,
+) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
+    match sd_version {
+        StableDiffusionVersion::XlInpaint
+        | StableDiffusionVersion::V2Inpaint
+        | StableDiffusionVersion::V1_5Inpaint => {
+            let inpaint_mask = mask_path.ok_or_else(|| {
+                anyhow::anyhow!("An inpainting model was requested but --mask-path was not provided.")
+            })?;
+            // Get the mask image with shape [1, 1, 1024, 1024]
+            let mask = mask_preprocess(inpaint_mask)?
+                .to_device(device)?
+                .to_dtype(dtype)?;
+            // Generate the masked image from the image and the mask with shape [1, 3, 1024, 1024]
+            let xmask = mask.le(0.5)?.repeat(&[1, 3, 1, 1])?.to_dtype(dtype)?;
+            let image = &image
+                .ok_or_else(|| anyhow::anyhow!(
+                    "An inpainting model was requested but --img2img, which provides the input image, was not provided."
+                ))?;
+            let masked_img = (image * xmask)?;
+            // Scale down the mask
+            let shape = masked_img.shape();
+            let (w, h) = (shape.dims()[3] / 8, shape.dims()[2] / 8);
+            let mask = mask.interpolate2d(w, h)?;
+            // shape: [1, 4, 128, 128]
+            let mask_latents = vae.encode(&masked_img)?;
+            let mask_latents = (mask_latents.sample()? * vae_scale)?.to_device(device)?;
+
+            let mask_4 = mask.as_ref().repeat(&[1, 4, 1, 1])?;
+            let (mask_latents, mask) = if use_guide_scale {
+                (
+                    Tensor::cat(&[&mask_latents, &mask_latents], 0)?,
+                    Tensor::cat(&[&mask, &mask], 0)?,
+                )
+            } else {
+                (mask_latents, mask)
+            };
+            Ok((Some(mask_latents), Some(mask), Some(mask_4)))
+        }
+        _ => Ok((None, None, None)),
+    }
+}
+
 fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
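The two functions above implement the usual latent-inpainting recipe: the grayscale mask is normalized to [0, 1] (white marks the region to repaint), a masked copy of the input image is produced by zeroing that region, and both are brought down to latent resolution (the mask by a direct 8x downscale, the masked image through the VAE). A self-contained shape walk-through of the mask side, as a sketch only; the `candle_core` crate name and the random stand-in data are assumptions:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Stand-ins for the decoded mask ([1, 1, H, W], 1.0 = repaint) and the
        // RGB input image ([1, 3, H, W]); the real code loads these from disk.
        let mask = Tensor::rand(0f32, 1f32, (1, 1, 512, 512), &dev)?;
        let image = Tensor::rand(0f32, 1f32, (1, 3, 512, 512), &dev)?;
        // Keep pixels where the mask is <= 0.5 (black); zero out the repaint region.
        let xmask = mask.le(0.5)?.repeat(&[1, 3, 1, 1])?.to_dtype(DType::F32)?;
        let masked_img = (&image * &xmask)?;
        assert_eq!(masked_img.dims(), &[1, 3, 512, 512]);
        // Downscale the mask by the VAE factor of 8 to latent resolution.
        let (w, h) = (512 / 8, 512 / 8);
        let mask_small = mask.interpolate2d(w, h)?;
        assert_eq!(mask_small.dims(), &[1, 1, 64, 64]);
        Ok(())
    }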
@@ -417,12 +546,14 @@ fn run(args: Args) -> Result<()> {
         bsize,
         sd_version,
         clip_weights,
+        clip2_weights,
         vae_weights,
         unet_weights,
         tracing,
         use_f16,
         guidance_scale,
         use_flash_attn,
+        mask_path,
         img2img,
         img2img_strength,
         seed,
@@ -445,7 +576,10 @@ fn run(args: Args) -> Result<()> {
         Some(guidance_scale) => guidance_scale,
         None => match sd_version {
             StableDiffusionVersion::V1_5
+            | StableDiffusionVersion::V1_5Inpaint
             | StableDiffusionVersion::V2_1
+            | StableDiffusionVersion::V2Inpaint
+            | StableDiffusionVersion::XlInpaint
             | StableDiffusionVersion::Xl => 7.5,
             StableDiffusionVersion::Turbo => 0.,
         },
@@ -454,20 +588,23 @@ fn run(args: Args) -> Result<()> {
         Some(n_steps) => n_steps,
         None => match sd_version {
             StableDiffusionVersion::V1_5
+            | StableDiffusionVersion::V1_5Inpaint
             | StableDiffusionVersion::V2_1
+            | StableDiffusionVersion::V2Inpaint
+            | StableDiffusionVersion::XlInpaint
             | StableDiffusionVersion::Xl => 30,
             StableDiffusionVersion::Turbo => 1,
         },
     };
     let dtype = if use_f16 { DType::F16 } else { DType::F32 };
     let sd_config = match sd_version {
-        StableDiffusionVersion::V1_5 => {
+        StableDiffusionVersion::V1_5 | StableDiffusionVersion::V1_5Inpaint => {
             stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
         }
-        StableDiffusionVersion::V2_1 => {
+        StableDiffusionVersion::V2_1 | StableDiffusionVersion::V2Inpaint => {
             stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
         }
-        StableDiffusionVersion::Xl => {
+        StableDiffusionVersion::Xl | StableDiffusionVersion::XlInpaint => {
             stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
         }
         StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
@@ -479,13 +616,16 @@ fn run(args: Args) -> Result<()> {
     let mut scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
-    if let Some(seed) = seed {
-        device.set_seed(seed)?;
-    }
+    // If a seed is not given, generate a random seed and print it
+    let seed = seed.unwrap_or(rand::thread_rng().gen_range(0u64..u64::MAX));
+    println!("Using seed {seed}");
+    device.set_seed(seed)?;
     let use_guide_scale = guidance_scale > 1.0;
 
     let which = match sd_version {
-        StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
+        StableDiffusionVersion::Xl
+        | StableDiffusionVersion::XlInpaint
+        | StableDiffusionVersion::Turbo => vec![true, false],
         _ => vec![true],
     };
     let text_embeddings = which
@@ -496,6 +636,7 @@ fn run(args: Args) -> Result<()> {
                 &uncond_prompt,
                 tokenizer.clone(),
                 clip_weights.clone(),
+                clip2_weights.clone(),
                 sd_version,
                 &sd_config,
                 use_f16,
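Note the behavioral change in seed handling above: a missing `--seed` no longer leaves the device RNG unseeded; instead a random seed is drawn and printed, so any run can be reproduced by passing the printed value back. A tiny sketch of the same logic, assuming the rand 0.8 API; `resolve_seed` is a hypothetical helper, and `gen_range(0u64..u64::MAX)` samples a half-open range, so `u64::MAX` itself is never drawn:

    use rand::Rng;

    // Hypothetical helper mirroring the patch's seed handling.
    fn resolve_seed(cli_seed: Option<u64>) -> u64 {
        // `unwrap_or_else` defers the RNG call to the `None` case; the patch's
        // `unwrap_or` draws a value eagerly even when a seed was supplied,
        // which is harmless but slightly wasteful.
        cli_seed.unwrap_or_else(|| rand::thread_rng().gen_range(0u64..u64::MAX))
    }

    fn main() {
        let seed = resolve_seed(None);
        println!("Using seed {seed}");
        assert_eq!(resolve_seed(Some(42)), 42);
    }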
@@ -514,16 +655,26 @@ fn run(args: Args) -> Result<()> {
     println!("Building the autoencoder.");
     let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
     let vae = sd_config.build_vae(vae_weights, &device, dtype)?;
-    let init_latent_dist = match &img2img {
-        None => None,
+
+    let (image, init_latent_dist) = match &img2img {
+        None => (None, None),
         Some(image) => {
-            let image = image_preprocess(image)?.to_device(&device)?;
-            Some(vae.encode(&image)?)
+            let image = image_preprocess(image)?
+                .to_device(&device)?
+                .to_dtype(dtype)?;
+            (Some(image.clone()), Some(vae.encode(&image)?))
         }
     };
+
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
-    let unet = sd_config.build_unet(unet_weights, &device, 4, use_flash_attn, dtype)?;
+    let in_channels = match sd_version {
+        StableDiffusionVersion::XlInpaint
+        | StableDiffusionVersion::V2Inpaint
+        | StableDiffusionVersion::V1_5Inpaint => 9,
+        _ => 4,
+    };
+    let unet = sd_config.build_unet(unet_weights, &device, in_channels, use_flash_attn, dtype)?;
 
     let t_start = if img2img.is_some() {
         n_steps - (n_steps as f64 * img2img_strength) as usize
@@ -533,11 +684,25 @@ fn run(args: Args) -> Result<()> {
 
     let vae_scale = match sd_version {
         StableDiffusionVersion::V1_5
+        | StableDiffusionVersion::V1_5Inpaint
         | StableDiffusionVersion::V2_1
+        | StableDiffusionVersion::V2Inpaint
+        | StableDiffusionVersion::XlInpaint
         | StableDiffusionVersion::Xl => 0.18215,
         StableDiffusionVersion::Turbo => 0.13025,
     };
 
+    let (mask_latents, mask, mask_4) = inpainting_tensors(
+        sd_version,
+        mask_path,
+        dtype,
+        &device,
+        use_guide_scale,
+        &vae,
+        image,
+        vae_scale,
+    )?;
+
     for idx in 0..num_samples {
         let timesteps = scheduler.timesteps().to_vec();
         let latents = match &init_latent_dist {
@@ -576,6 +741,22 @@ fn run(args: Args) -> Result<()> {
             };
             let latent_model_input =
                 scheduler.scale_model_input(latent_model_input, timestep)?;
+
+            let latent_model_input = match sd_version {
+                StableDiffusionVersion::XlInpaint
+                | StableDiffusionVersion::V2Inpaint
+                | StableDiffusionVersion::V1_5Inpaint => Tensor::cat(
+                    &[
+                        &latent_model_input,
+                        mask.as_ref().unwrap(),
+                        mask_latents.as_ref().unwrap(),
+                    ],
+                    1,
+                )?,
+                _ => latent_model_input,
+            }
+            .to_device(&device)?;
+
             let noise_pred =
                 unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
@@ -592,6 +773,18 @@ fn run(args: Args) -> Result<()> {
             let dt = start_time.elapsed().as_secs_f32();
             println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
 
+            // Replace all pixels in the unmasked region with the original pixels, discarding any changes.
+            if args.only_update_masked {
+                let mask = mask_4.as_ref().unwrap();
+                let latent_to_keep = mask_latents
+                    .as_ref()
+                    .unwrap()
+                    .get_on_dim(0, 0)? // shape: [4, H, W]
+                    .unsqueeze(0)?; // shape: [1, 4, H, W]
+
+                latents = ((&latents * mask)? + &latent_to_keep * (1.0 - mask))?;
+            }
+
             if args.intermediary_images {
                 save_image(
                     &vae,
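For reference, the two tensor operations the patch wires into the denoising loop, as a standalone sketch: the channel concatenation that builds the 9-channel input inpainting UNets expect (4 noisy-latent + 1 mask + 4 masked-image-latent channels), and the `--only-update-masked` blend that restores the original-image latents outside the mask after each scheduler step. The `candle_core` crate name, the helper names, and the random stand-in tensors are assumptions for the example:

    use candle_core::{DType, Device, Result, Tensor};

    // Build the 9-channel UNet input by concatenating along the channel dim (1).
    fn assemble_inpaint_input(
        latent_model_input: &Tensor, // [b, 4, h, w]
        mask: &Tensor,               // [b, 1, h, w]
        mask_latents: &Tensor,       // [b, 4, h, w]
    ) -> Result<Tensor> {
        Tensor::cat(&[latent_model_input, mask, mask_latents], 1)
    }

    // latents * mask + original * (1 - mask): inside the mask the generated
    // latents survive; outside it the original-image latents are restored.
    fn keep_unmasked_region(latents: &Tensor, mask_4: &Tensor, original: &Tensor) -> Result<Tensor> {
        (latents * mask_4)? + (original * (1.0 - mask_4)?)?
    }

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, h, w) = (1, 64, 64);
        let latent = Tensor::randn(0f32, 1f32, (b, 4, h, w), &dev)?;
        let mask = Tensor::ones((b, 1, h, w), DType::F32, &dev)?;
        let mask_latents = Tensor::randn(0f32, 1f32, (b, 4, h, w), &dev)?;

        let unet_input = assemble_inpaint_input(&latent, &mask, &mask_latents)?;
        assert_eq!(unet_input.dims(), &[b, 9, h, w]);

        // mask_4 mirrors the patch: the 1-channel mask repeated to 4 channels.
        let mask_4 = mask.repeat(&[1, 4, 1, 1])?;
        let blended = keep_unmasked_region(&latent, &mask_4, &mask_latents)?;
        assert_eq!(blended.dims(), &[b, 4, h, w]);
        Ok(())
    }

In the patch itself, `mask_latents` holds the VAE encoding of the masked image and `mask_4` is the downscaled mask repeated to four channels, with both the mask and the latents batch-doubled when classifier-free guidance is active.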