From 693fad511ca4a52040f5c5f4aae1ee8c43d544ed Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 1 Nov 2023 15:37:52 +0100
Subject: [PATCH] Preliminary support for ssd1b. (#1233)

---
 .../src/models/stable_diffusion/clip.rs |  8 +++
 .../src/models/stable_diffusion/mod.rs  | 65 +++++++++++++++++++
 2 files changed, 73 insertions(+)

diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index e7a20270..20e8ceac 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -102,6 +102,14 @@ impl Config {
         }
     }
 
+    pub fn ssd1b() -> Self {
+        Self::sdxl()
+    }
+
+    pub fn ssd1b2() -> Self {
+        Self::sdxl2()
+    }
+
     // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
     pub fn wuerstchen() -> Self {
         Self {
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index 7fdedaae..66ef7149 100644
--- a/candle-transformers/src/models/stable_diffusion/mod.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -249,6 +249,71 @@ impl StableDiffusionConfig {
         )
     }
 
+    pub fn ssd1b(
+        sliced_attention_size: Option<usize>,
+        height: Option<usize>,
+        width: Option<usize>,
+    ) -> Self {
+        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+            out_channels,
+            use_cross_attn,
+            attention_head_dim,
+        };
+        // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
+        let unet = unet_2d::UNet2DConditionModelConfig {
+            blocks: vec![
+                bc(320, None, 5),
+                bc(640, Some(2), 10),
+                bc(1280, Some(10), 20),
+            ],
+            center_input_sample: false,
+            cross_attention_dim: 2048,
+            downsample_padding: 1,
+            flip_sin_to_cos: true,
+            freq_shift: 0.,
+            layers_per_block: 2,
+            mid_block_scale_factor: 1.,
+            norm_eps: 1e-5,
+            norm_num_groups: 32,
+            sliced_attention_size,
+            use_linear_projection: true,
+        };
+        // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
+        let autoencoder = vae::AutoEncoderKLConfig {
+            block_out_channels: vec![128, 256, 512, 512],
+            layers_per_block: 2,
+            latent_channels: 4,
+            norm_num_groups: 32,
+        };
+        let scheduler = ddim::DDIMSchedulerConfig {
+            ..Default::default()
+        };
+
+        let height = if let Some(height) = height {
+            assert_eq!(height % 8, 0, "height has to be divisible by 8");
+            height
+        } else {
+            1024
+        };
+
+        let width = if let Some(width) = width {
+            assert_eq!(width % 8, 0, "width has to be divisible by 8");
+            width
+        } else {
+            1024
+        };
+
+        Self {
+            width,
+            height,
+            clip: clip::Config::ssd1b(),
+            clip2: Some(clip::Config::ssd1b2()),
+            autoencoder,
+            scheduler,
+            unet,
+        }
+    }
+
     pub fn build_vae<P: AsRef<std::path::Path>>(
         &self,
         vae_weights: P,
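
Not part of the patch: a minimal usage sketch, assuming the `ssd1b` constructor added above is reachable as `candle_transformers::models::stable_diffusion::StableDiffusionConfig::ssd1b` and that `width` and `height` are public fields on the config struct.

// Hypothetical usage sketch, not part of the patch.
use candle_transformers::models::stable_diffusion::StableDiffusionConfig;

fn main() {
    // Passing `None` everywhere keeps full attention and falls back to the
    // 1024x1024 default resolution set in `ssd1b`.
    let config = StableDiffusionConfig::ssd1b(None, None, None);
    println!("SSD-1B output resolution: {}x{}", config.width, config.height);

    // Explicit resolutions must be divisible by 8, as asserted in `ssd1b`.
    let config_768 = StableDiffusionConfig::ssd1b(None, Some(768), Some(768));
    assert_eq!((config_768.height, config_768.width), (768, 768));
}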