From 62ef494dc17c1f582b28c665e78f2aa78d846bb9 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 29 Aug 2023 11:13:43 +0100
Subject: [PATCH] Use multiple transformer layer in the same cross-attn blocks. (#653)

* Use multiple transformer layer in the same cross-attn blocks.

* Make the context contiguous if required.
---
 .../examples/stable-diffusion/attention.rs    |  6 ++---
 .../stable-diffusion/stable_diffusion.rs      | 22 +++++++++-------
 .../examples/stable-diffusion/unet_2d.rs      | 25 +++++++++++++------
 .../stable-diffusion/unet_2d_blocks.rs        | 12 ++++++---
 4 files changed, 43 insertions(+), 22 deletions(-)

diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index 797542aa..1ae1bfc3 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -208,9 +208,9 @@ impl CrossAttention {
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
-        let context = context.unwrap_or(xs);
-        let key = self.to_k.forward(context)?;
-        let value = self.to_v.forward(context)?;
+        let context = context.unwrap_or(xs).contiguous()?;
+        let key = self.to_k.forward(&context)?;
+        let value = self.to_v.forward(&context)?;
         let query = self.reshape_heads_to_batch_dim(&query)?;
         let key = self.reshape_heads_to_batch_dim(&key)?;
         let value = self.reshape_heads_to_batch_dim(&value)?;
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index bed60161..cffc00d8 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -28,10 +28,10 @@ impl StableDiffusionConfig {
         // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
         let unet = unet_2d::UNet2DConditionModelConfig {
             blocks: vec![
-                bc(320, true, 8),
-                bc(640, true, 8),
-                bc(1280, true, 8),
-                bc(1280, false, 8),
+                bc(320, Some(1), 8),
+                bc(640, Some(1), 8),
+                bc(1280, Some(1), 8),
+                bc(1280, None, 8),
             ],
             center_input_sample: false,
             cross_attention_dim: 768,
@@ -90,10 +90,10 @@ impl StableDiffusionConfig {
         // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
         let unet = unet_2d::UNet2DConditionModelConfig {
             blocks: vec![
-                bc(320, true, 5),
-                bc(640, true, 10),
-                bc(1280, true, 20),
-                bc(1280, false, 20),
+                bc(320, Some(1), 5),
+                bc(640, Some(1), 10),
+                bc(1280, Some(1), 20),
+                bc(1280, None, 20),
             ],
             center_input_sample: false,
             cross_attention_dim: 1024,
@@ -171,7 +171,11 @@ impl StableDiffusionConfig {
         };
         // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
         let unet = unet_2d::UNet2DConditionModelConfig {
-            blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)],
+            blocks: vec![
+                bc(320, None, 5),
+                bc(640, Some(2), 10),
+                bc(1280, Some(10), 20),
+            ],
             center_input_sample: false,
             cross_attention_dim: 2048,
             downsample_padding: 1,
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index eb2dbf10..81bd9547 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -12,7 +12,9 @@ use candle_nn::Module;
 #[derive(Debug, Clone, Copy)]
 pub struct BlockConfig {
     pub out_channels: usize,
-    pub use_cross_attn: bool,
+    /// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the
+    /// number of transformer blocks to be used.
+    pub use_cross_attn: Option<usize>,
     pub attention_head_dim: usize,
 }
 
@@ -41,22 +43,22 @@ impl Default for UNet2DConditionModelConfig {
             blocks: vec![
                 BlockConfig {
                     out_channels: 320,
-                    use_cross_attn: true,
+                    use_cross_attn: Some(1),
                     attention_head_dim: 8,
                 },
                 BlockConfig {
                     out_channels: 640,
-                    use_cross_attn: true,
+                    use_cross_attn: Some(1),
                     attention_head_dim: 8,
                 },
                 BlockConfig {
                     out_channels: 1280,
-                    use_cross_attn: true,
+                    use_cross_attn: Some(1),
                     attention_head_dim: 8,
                 },
                 BlockConfig {
                     out_channels: 1280,
-                    use_cross_attn: false,
+                    use_cross_attn: None,
                     attention_head_dim: 8,
                 },
             ],
@@ -149,13 +151,14 @@ impl UNet2DConditionModel {
                     downsample_padding: config.downsample_padding,
                     ..Default::default()
                 };
-                if use_cross_attn {
+                if let Some(transformer_layers_per_block) = use_cross_attn {
                     let config = CrossAttnDownBlock2DConfig {
                         downblock: db_cfg,
                         attn_num_head_channels: attention_head_dim,
                         cross_attention_dim: config.cross_attention_dim,
                         sliced_attention_size,
                         use_linear_projection: config.use_linear_projection,
+                        transformer_layers_per_block,
                     };
                     let block = CrossAttnDownBlock2D::new(
                         vs_db.pp(&i.to_string()),
@@ -179,6 +182,11 @@ impl UNet2DConditionModel {
             })
             .collect::<Result<Vec<_>>>()?;
 
+        // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
+        let mid_transformer_layers_per_block = match config.blocks.last() {
+            None => 1,
+            Some(block) => block.use_cross_attn.unwrap_or(1),
+        };
         let mid_cfg = UNetMidBlock2DCrossAttnConfig {
             resnet_eps: config.norm_eps,
             output_scale_factor: config.mid_block_scale_factor,
             cross_attn_dim: config.cross_attention_dim,
             attn_num_head_channels: bl_attention_head_dim,
             resnet_groups: Some(config.norm_num_groups),
             use_linear_projection: config.use_linear_projection,
+            transformer_layers_per_block: mid_transformer_layers_per_block,
             ..Default::default()
         };
+
         let mid_block = UNetMidBlock2DCrossAttn::new(
             vs.pp("mid_block"),
             bl_channels,
@@ -231,13 +241,14 @@ impl UNet2DConditionModel {
                     add_upsample: i < n_blocks - 1,
                     ..Default::default()
                 };
-                if use_cross_attn {
+                if let Some(transformer_layers_per_block) = use_cross_attn {
                     let config = CrossAttnUpBlock2DConfig {
                         upblock: ub_cfg,
                         attn_num_head_channels: attention_head_dim,
                         cross_attention_dim: config.cross_attention_dim,
                         sliced_attention_size,
                         use_linear_projection: config.use_linear_projection,
+                        transformer_layers_per_block,
                     };
                     let block = CrossAttnUpBlock2D::new(
                         vs_ub.pp(&i.to_string()),
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 65341e74..1db65222 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -366,6 +366,7 @@ pub struct UNetMidBlock2DCrossAttnConfig {
     pub cross_attn_dim: usize,
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for UNetMidBlock2DCrossAttnConfig {
@@ -379,6 +380,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
             cross_attn_dim: 1280,
             sliced_attention_size: None, // Sliced attention disabled
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -414,7 +416,7 @@ impl UNetMidBlock2DCrossAttn {
         let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
         let n_heads = config.attn_num_head_channels;
         let attn_cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             num_groups: resnet_groups,
             context_dim: Some(config.cross_attn_dim),
             sliced_attention_size: config.sliced_attention_size,
@@ -565,6 +567,7 @@ pub struct CrossAttnDownBlock2DConfig {
     // attention_type: "default"
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for CrossAttnDownBlock2DConfig {
@@ -575,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
             cross_attention_dim: 1280,
             sliced_attention_size: None,
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -605,7 +609,7 @@ impl CrossAttnDownBlock2D {
         )?;
         let n_heads = config.attn_num_head_channels;
         let cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             context_dim: Some(config.cross_attention_dim),
             num_groups: config.downblock.resnet_groups,
             sliced_attention_size: config.sliced_attention_size,
@@ -767,6 +771,7 @@ pub struct CrossAttnUpBlock2DConfig {
     // attention_type: "default"
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for CrossAttnUpBlock2DConfig {
@@ -777,6 +782,7 @@ impl Default for CrossAttnUpBlock2DConfig {
             cross_attention_dim: 1280,
             sliced_attention_size: None,
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -809,7 +815,7 @@ impl CrossAttnUpBlock2D {
         )?;
         let n_heads = config.attn_num_head_channels;
         let cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             context_dim: Some(config.cross_attention_dim),
             num_groups: config.upblock.resnet_groups,
             sliced_attention_size: config.sliced_attention_size,
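
Note on the reworked config (not part of the patch): the sketch below is a minimal, self-contained illustration of how `use_cross_attn: Option<usize>` is consumed after this change. `Some(d)` selects a cross-attention down/up block whose SpatialTransformer stacks `d` transformer layers (the `depth` field above), `None` selects a plain block, and the mid block reuses the depth of the last block, defaulting to 1. The `BlockConfig` fields and the `bc` helper mirror the diff; the standalone `main` and the printed summary are purely illustrative and are not candle API.

// Sketch only: standalone types mirroring the patch, not candle's real modules.
#[derive(Debug, Clone, Copy)]
struct BlockConfig {
    out_channels: usize,
    /// `None`: plain block, no cross-attention.
    /// `Some(d)`: cross-attention block with `d` stacked transformer layers.
    use_cross_attn: Option<usize>,
    attention_head_dim: usize,
}

/// Same shape as the `bc` helper used in stable_diffusion.rs.
fn bc(out_channels: usize, use_cross_attn: Option<usize>, attention_head_dim: usize) -> BlockConfig {
    BlockConfig { out_channels, use_cross_attn, attention_head_dim }
}

fn main() {
    // SDXL-base block layout from the patch.
    let blocks = vec![bc(320, None, 5), bc(640, Some(2), 10), bc(1280, Some(10), 20)];

    for (i, b) in blocks.iter().enumerate() {
        match b.use_cross_attn {
            // Becomes a cross-attn block; `d` feeds SpatialTransformerConfig::depth.
            Some(d) => println!(
                "block {i}: {} ch, {} head dim, cross-attn depth {d}",
                b.out_channels, b.attention_head_dim
            ),
            // Becomes a plain (no cross-attention) block.
            None => println!("block {i}: {} ch, no cross-attention", b.out_channels),
        }
    }

    // Mid block depth: last block's layer count, defaulting to 1 (as in unet_2d.rs).
    let mid_depth = blocks.last().and_then(|b| b.use_cross_attn).unwrap_or(1);
    println!("mid block: cross-attn depth {mid_depth}");
}

Running this prints the per-block depths (none, 2, 10) and a mid-block depth of 10, which is the behaviour the `mid_transformer_layers_per_block` match added in unet_2d.rs implements.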