From db3e9dae048db2272ec6e4478f8f503c4b6745b6 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 17 Sep 2023 13:46:38 +0200
Subject: [PATCH] Wuerstchen main (#876)

* Wuerstchen main.

* More of the wuerstchen cli example.

* Paella creation.

* Build the prior model.

* Fix the weight file names.
---
 candle-examples/examples/wuerstchen/main.rs   | 359 +++++-------------
 .../src/models/stable_diffusion/clip.rs       |  15 +
 .../src/models/wuerstchen/paella_vq.rs        | 112 +++++-
 3 files changed, 213 insertions(+), 273 deletions(-)

diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index c8b771a0..32c7d158 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -1,3 +1,5 @@
+#![allow(unused)]
+
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
@@ -5,6 +7,7 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use candle_transformers::models::stable_diffusion;
+use candle_transformers::models::wuerstchen;
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, IndexOp, Module, Tensor, D};
@@ -42,17 +45,21 @@ struct Args {
     #[arg(long)]
     width: Option<usize>,
 
-    /// The UNet weight file, in .safetensors format.
+    /// The decoder weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
-    unet_weights: Option<String>,
+    decoder_weights: Option<String>,
 
     /// The CLIP weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     clip_weights: Option<String>,
 
-    /// The VAE weight file, in .safetensors format.
+    /// The prior weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
-    vae_weights: Option<String>,
+    prior_weights: Option<String>,
+
+    /// The VQGAN weight file, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    vqgan_weights: Option<String>,
 
     #[arg(long, value_name = "FILE")]
     /// The file specifying the tokenizer to used for tokenization.
@@ -73,138 +80,31 @@ struct Args {
     /// The name of the final image to generate.
    #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
     final_image: String,
-
-    #[arg(long, value_enum, default_value = "v2-1")]
-    sd_version: StableDiffusionVersion,
-
-    /// Generate intermediary images at each step.
-    #[arg(long, action)]
-    intermediary_images: bool,
-
-    #[arg(long)]
-    use_flash_attn: bool,
-
-    #[arg(long)]
-    use_f16: bool,
-
-    #[arg(long, value_name = "FILE")]
-    img2img: Option<String>,
-
-    /// The strength, indicates how much to transform the initial image. The
-    /// value must be between 0 and 1, a value of 1 discards the initial image
-    /// information.
-    #[arg(long, default_value_t = 0.8)]
-    img2img_strength: f64,
 }
 
-#[derive(Debug, Clone, Copy, clap::ValueEnum)]
-enum StableDiffusionVersion {
-    V1_5,
-    V2_1,
-    Xl,
-}
-
-#[allow(unused)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum ModelFile {
     Tokenizer,
-    Tokenizer2,
     Clip,
-    Clip2,
-    Unet,
-    Vae,
-}
-
-impl StableDiffusionVersion {
-    fn repo(&self) -> &'static str {
-        match self {
-            Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
-            Self::V2_1 => "stabilityai/stable-diffusion-2-1",
-            Self::V1_5 => "runwayml/stable-diffusion-v1-5",
-        }
-    }
-
-    fn unet_file(&self, use_f16: bool) -> &'static str {
-        match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl => {
-                if use_f16 {
-                    "unet/diffusion_pytorch_model.fp16.safetensors"
-                } else {
-                    "unet/diffusion_pytorch_model.safetensors"
-                }
-            }
-        }
-    }
-
-    fn vae_file(&self, use_f16: bool) -> &'static str {
-        match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl => {
-                if use_f16 {
-                    "vae/diffusion_pytorch_model.fp16.safetensors"
-                } else {
-                    "vae/diffusion_pytorch_model.safetensors"
-                }
-            }
-        }
-    }
-
-    fn clip_file(&self, use_f16: bool) -> &'static str {
-        match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl => {
-                if use_f16 {
-                    "text_encoder/model.fp16.safetensors"
-                } else {
-                    "text_encoder/model.safetensors"
-                }
-            }
-        }
-    }
-
-    fn clip2_file(&self, use_f16: bool) -> &'static str {
-        match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl => {
-                if use_f16 {
-                    "text_encoder_2/model.fp16.safetensors"
-                } else {
-                    "text_encoder_2/model.safetensors"
-                }
-            }
-        }
-    }
+    Decoder,
+    VqGan,
+    Prior,
 }
 
 impl ModelFile {
-    fn get(
-        &self,
-        filename: Option<String>,
-        version: StableDiffusionVersion,
-        use_f16: bool,
-    ) -> Result<std::path::PathBuf> {
+    fn get(&self, filename: Option<String>) -> Result<std::path::PathBuf> {
         use hf_hub::api::sync::Api;
         match filename {
             Some(filename) => Ok(std::path::PathBuf::from(filename)),
             None => {
+                let repo_main = "warp-ai/wuerstchen";
+                let repo_prior = "warp-ai/wuerstchen-prior";
                 let (repo, path) = match self {
-                    Self::Tokenizer => {
-                        let tokenizer_repo = match version {
-                            StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
-                                "openai/clip-vit-base-patch32"
-                            }
-                            StableDiffusionVersion::Xl => {
-                                // This seems similar to the patch32 version except some very small
-                                // difference in the split regex.
- "openai/clip-vit-large-patch14" - } - }; - (tokenizer_repo, "tokenizer.json") - } - Self::Tokenizer2 => { - ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") - } - Self::Clip => (version.repo(), version.clip_file(use_f16)), - Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), - Self::Unet => (version.repo(), version.unet_file(use_f16)), - Self::Vae => (version.repo(), version.vae_file(use_f16)), + Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"), + Self::Clip => (repo_main, "text_encoder/model.safetensors"), + Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"), + Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"), + Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"), }; let filename = Api::new()?.model(repo.to_string()).get(path)?; Ok(filename) @@ -240,27 +140,17 @@ fn output_filename( } } -#[allow(clippy::too_many_arguments)] -fn text_embeddings( +fn encode_prompt( prompt: &str, uncond_prompt: &str, tokenizer: Option, clip_weights: Option, - sd_version: StableDiffusionVersion, - sd_config: &stable_diffusion::StableDiffusionConfig, - use_f16: bool, + clip_config: stable_diffusion::clip::Config, device: &Device, - dtype: DType, - first: bool, ) -> Result { - let tokenizer_file = if first { - ModelFile::Tokenizer - } else { - ModelFile::Tokenizer2 - }; - let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?; + let tokenizer = ModelFile::Tokenizer.get(tokenizer)?; let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; - let pad_id = match &sd_config.clip.pad_with { + let pad_id = match &clip_config.pad_with { Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(), None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(), }; @@ -270,7 +160,7 @@ fn text_embeddings( .map_err(E::msg)? .get_ids() .to_vec(); - while tokens.len() < sd_config.clip.max_position_embeddings { + while tokens.len() < clip_config.max_position_embeddings { tokens.push(pad_id) } let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; @@ -280,51 +170,21 @@ fn text_embeddings( .map_err(E::msg)? 
         .get_ids()
         .to_vec();
-    while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+    while uncond_tokens.len() < clip_config.max_position_embeddings {
         uncond_tokens.push(pad_id)
     }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
 
     println!("Building the Clip transformer.");
-    let clip_weights_file = if first {
-        ModelFile::Clip
-    } else {
-        ModelFile::Clip2
-    };
-    let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
-    let clip_config = if first {
-        &sd_config.clip
-    } else {
-        sd_config.clip2.as_ref().unwrap()
-    };
+    let clip_weights = ModelFile::Clip.get(clip_weights)?;
     let text_model =
-        stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
+        stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
     let text_embeddings = text_model.forward(&tokens)?;
     let uncond_embeddings = text_model.forward(&uncond_tokens)?;
-    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
     Ok(text_embeddings)
 }
 
-fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
-    let img = image::io::Reader::open(path)?.decode()?;
-    let (height, width) = (img.height() as usize, img.width() as usize);
-    let height = height - height % 32;
-    let width = width - width % 32;
-    let img = img.resize_to_fill(
-        width as u32,
-        height as u32,
-        image::imageops::FilterType::CatmullRom,
-    );
-    let img = img.to_rgb8();
-    let img = img.into_raw();
-    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
-        .permute((2, 0, 1))?
-        .to_dtype(DType::F32)?
-        .affine(2. / 255., -1.)?
-        .unsqueeze(0)?;
-    Ok(img)
-}
-
 fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -340,22 +200,14 @@ fn run(args: Args) -> Result<()> {
         final_image,
         sliced_attention_size,
         num_samples,
-        sd_version,
         clip_weights,
-        vae_weights,
-        unet_weights,
+        prior_weights,
+        vqgan_weights,
+        decoder_weights,
         tracing,
-        use_f16,
-        use_flash_attn,
-        img2img,
-        img2img_strength,
         ..
     } = args;
 
-    if !(0. ..=1.).contains(&img2img_strength) {
-        anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
-    }
-
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -364,102 +216,75 @@ fn run(args: Args) -> Result<()> {
         None
     };
 
-    let dtype = if use_f16 { DType::F16 } else { DType::F32 };
-    let sd_config = match sd_version {
-        StableDiffusionVersion::V1_5 => {
-            stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
-        }
-        StableDiffusionVersion::V2_1 => {
-            stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
-        }
-        StableDiffusionVersion::Xl => {
-            stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
-        }
-    };
-
-    let scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
 
-    let which = match sd_version {
-        StableDiffusionVersion::Xl => vec![true, false],
-        _ => vec![true],
-    };
-    let text_embeddings = which
-        .iter()
-        .map(|first| {
-            text_embeddings(
-                &prompt,
-                &uncond_prompt,
-                tokenizer.clone(),
-                clip_weights.clone(),
-                sd_version,
-                &sd_config,
-                use_f16,
-                &device,
-                dtype,
-                *first,
-            )
-        })
-        .collect::<Result<Vec<_>>>()?;
-    let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+    let text_embeddings = encode_prompt(
+        &prompt,
+        &uncond_prompt,
+        tokenizer.clone(),
+        clip_weights.clone(),
+        stable_diffusion::clip::Config::wuerstchen(),
+        &device,
+    );
     println!("{text_embeddings:?}");
 
-    println!("Building the autoencoder.");
-    let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
-    let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
-    let init_latent_dist = match &img2img {
-        None => None,
-        Some(image) => {
-            let image = image_preprocess(image)?.to_device(&device)?;
-            Some(vae.encode(&image)?)
-        }
+    println!("Building the prior.");
+    // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
+    let _prior = {
+        let prior_weights = ModelFile::Prior.get(prior_weights)?;
+        let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
+        let weights = weights.deserialize()?;
+        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+        wuerstchen::prior::WPrior::new(
+            /* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64,
+            /* depth */ 32, /* nhead */ 24, vb,
+        )
     };
-    println!("Building the unet.");
-    let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
 
-    let t_start = if img2img.is_some() {
-        n_steps - (n_steps as f64 * img2img_strength) as usize
-    } else {
-        0
+    println!("Building the vqgan.");
+    let _vqgan = {
+        let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
+        let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
+        let weights = weights.deserialize()?;
+        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+        wuerstchen::paella_vq::PaellaVQ::new(vb)?
    };
-    let bsize = 1;
+
+    println!("Building the decoder.");
+
+    // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
+    let _decoder = {
+        let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
+        let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
+        let weights = weights.deserialize()?;
+        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+        wuerstchen::diffnext::WDiffNeXt::new(
+            /* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
+            /* clip_embd */ 1024, /* patch_size */ 2, vb,
+        )?
+    };
+
+    let _bsize = 1;
     for idx in 0..num_samples {
+        /*
         let timesteps = scheduler.timesteps();
-        let latents = match &init_latent_dist {
-            Some(init_latent_dist) => {
-                let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
-                if t_start < timesteps.len() {
-                    let noise = latents.randn_like(0f64, 1f64)?;
-                    scheduler.add_noise(&latents, noise, timesteps[t_start])?
-                } else {
-                    latents
-                }
-            }
-            None => {
-                let latents = Tensor::randn(
-                    0f32,
-                    1f32,
-                    (bsize, 4, sd_config.height / 8, sd_config.width / 8),
-                    &device,
-                )?;
-                // scale the initial noise by the standard deviation required by the scheduler
-                (latents * scheduler.init_noise_sigma())?
-            }
-        };
-        let mut latents = latents.to_dtype(dtype)?;
+        let latents = Tensor::randn(
+            0f32,
+            1f32,
+            (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+            &device,
+        )?;
+        // scale the initial noise by the standard deviation required by the scheduler
+        let mut latents = latents * scheduler.init_noise_sigma()?;
 
         println!("starting sampling");
         for (timestep_index, &timestep) in timesteps.iter().enumerate() {
-            if timestep_index < t_start {
-                continue;
-            }
             let start_time = std::time::Instant::now();
             let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
 
             let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
             let noise_pred =
-                unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
+                decoder.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
             let noise_pred = noise_pred.chunk(2, 0)?;
             let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
             let noise_pred =
@@ -467,28 +292,22 @@ fn run(args: Args) -> Result<()> {
             latents = scheduler.step(&noise_pred, timestep, &latents)?;
             let dt = start_time.elapsed().as_secs_f32();
             println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
-
-            if args.intermediary_images {
-                let image = vae.decode(&(&latents / 0.18215)?)?;
-                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
-                let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
-                let image_filename =
-                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
-                candle_examples::save_image(&image, image_filename)?
-            }
         }
+        */
 
         println!(
             "Generating the final image for sample {}/{}.",
             idx + 1,
             num_samples
        );
+        /*
         let image = vae.decode(&(&latents / 0.18215)?)?;
         // TODO: Add the clamping between 0 and 1.
         let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
         let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
         let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
         candle_examples::save_image(&image, image_filename)?
+        */
     }
     Ok(())
 }
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 397a1cef..31d025b3 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -99,6 +99,21 @@ impl Config {
             activation: Activation::Gelu,
         }
     }
+
+    // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
+    pub fn wuerstchen() -> Self {
+        Self {
+            vocab_size: 49408,
+            embed_dim: 1024,
+            intermediate_size: 4096,
+            max_position_embeddings: 77,
+            pad_with: Some("!".to_string()),
+            num_hidden_layers: 24,
+            num_attention_heads: 16,
+            projection_dim: 1024,
+            activation: Activation::Gelu,
+        }
+    }
 }
 
 // CLIP Text Model
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index 1268047a..a60f8e8a 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -65,17 +65,121 @@ impl Module for MixingResidualBlock {
 }
 
 #[derive(Debug)]
-struct PaellaVQ {
+pub struct PaellaVQ {
     in_block_conv: candle_nn::Conv2d,
     out_block_conv: candle_nn::Conv2d,
     down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
     down_blocks_conv: candle_nn::Conv2d,
     down_blocks_bn: candle_nn::BatchNorm,
     up_blocks_conv: candle_nn::Conv2d,
-    up_blocks: Vec<(MixingResidualBlock, Option<candle_nn::ConvTranspose2d>)>,
+    up_blocks: Vec<(Vec<MixingResidualBlock>, Option<candle_nn::ConvTranspose2d>)>,
 }
 
 impl PaellaVQ {
+    pub fn new(vb: VarBuilder) -> Result<Self> {
+        const IN_CHANNELS: usize = 3;
+        const OUT_CHANNELS: usize = 3;
+        const LATENT_CHANNELS: usize = 4;
+        const EMBED_DIM: usize = 384;
+        const BOTTLENECK_BLOCKS: usize = 12;
+        const C_LEVELS: [usize; 2] = [EMBED_DIM / 2, EMBED_DIM];
+
+        let in_block_conv = candle_nn::conv2d(
+            IN_CHANNELS * 4,
+            C_LEVELS[0],
+            1,
+            Default::default(),
+            vb.pp("in_block.1"),
+        )?;
+        let out_block_conv = candle_nn::conv2d(
+            C_LEVELS[0],
+            OUT_CHANNELS * 4,
+            1,
+            Default::default(),
+            vb.pp("out_block.0"),
+        )?;
+
+        let mut down_blocks = Vec::new();
+        let vb_d = vb.pp("down_blocks");
+        let mut d_idx = 0;
+        for (i, &c_level) in C_LEVELS.iter().enumerate() {
+            let conv_block = if i > 0 {
+                let cfg = candle_nn::Conv2dConfig {
+                    padding: 1,
+                    stride: 2,
+                    ..Default::default()
+                };
+                let block =
+                    candle_nn::conv2d_no_bias(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?;
+                d_idx += 1;
+                Some(block)
+            } else {
+                None
+            };
+            let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_d.pp(d_idx))?;
+            d_idx += 1;
+            down_blocks.push((conv_block, res_block))
+        }
+        let down_blocks_conv = candle_nn::conv2d_no_bias(
+            C_LEVELS[1],
+            LATENT_CHANNELS,
+            1,
+            Default::default(),
+            vb_d.pp(d_idx),
+        )?;
+        d_idx += 1;
+        let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(d_idx))?;
+
+        let mut up_blocks = Vec::new();
+        let vb_u = vb.pp("up_blocks");
+        let mut u_idx = 0;
+        let up_blocks_conv = candle_nn::conv2d_no_bias(
+            LATENT_CHANNELS,
+            C_LEVELS[1],
+            1,
+            Default::default(),
+            vb_u.pp(u_idx),
+        )?;
+        u_idx += 1;
+        for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
+            let mut res_blocks = Vec::new();
+            let n_bottleneck_blocks = if i == 0 { BOTTLENECK_BLOCKS } else { 1 };
+            for _j in 0..n_bottleneck_blocks {
+                let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_u.pp(u_idx))?;
+                u_idx += 1;
+                res_blocks.push(res_block)
+            }
+            let conv_block = if i < C_LEVELS.len() - 1 {
+                let cfg = candle_nn::ConvTranspose2dConfig {
+                    padding: 1,
+                    stride: 2,
+                    ..Default::default()
+                };
+                let block = candle_nn::conv_transpose2d_no_bias(
+                    c_level,
+                    C_LEVELS[i - 1],
+                    4,
+                    cfg,
+                    vb_u.pp(u_idx),
+                )?;
+                u_idx += 1;
+                Some(block)
+            } else {
+                None
+            };
+            up_blocks.push((res_blocks, conv_block))
+        }
+        Ok(Self {
+            in_block_conv,
+            down_blocks,
+            down_blocks_conv,
+            down_blocks_bn,
+            up_blocks,
+            up_blocks_conv,
+            out_block_conv,
+        })
+    }
+
     pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
         let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
         for down_block in self.down_blocks.iter() {
@@ -92,7 +196,9 @@ impl PaellaVQ {
         // TODO: quantizer if we want to support `force_not_quantize=False`.
         let mut xs = xs.apply(&self.up_blocks_conv)?;
         for up_block in self.up_blocks.iter() {
-            xs = xs.apply(&up_block.0)?;
+            for b in up_block.0.iter() {
+                xs = xs.apply(b)?;
+            }
             if let Some(conv) = &up_block.1 {
                 xs = xs.apply(conv)?
             }