Use multiple transformer layers in the same cross-attn blocks. (#653)

* Use multiple transformer layers in the same cross-attn blocks.

* Make the context contiguous if required.
Author: Laurent Mazare
Date: 2023-08-29 11:13:43 +01:00
Committed by: GitHub
Parent: d0a330448d
Commit: 62ef494dc1
4 changed files with 43 additions and 22 deletions
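
Taken together, the hunks below thread a new transformer_layers_per_block option through the mid, down, and up cross-attention block configs and use it as the depth of each SpatialTransformerConfig, replacing the previously hard-coded value of 1. A minimal usage sketch, relying only on the struct names, field names, and Default impls visible in this diff (the function itself is illustrative and not part of the commit):

fn multi_layer_configs() -> (
    UNetMidBlock2DCrossAttnConfig,
    CrossAttnDownBlock2DConfig,
    CrossAttnUpBlock2DConfig,
) {
    // Ask for two stacked transformer layers per cross-attn block; the
    // default of 1 keeps the old single-layer behaviour.
    let mid = UNetMidBlock2DCrossAttnConfig {
        transformer_layers_per_block: 2,
        ..Default::default()
    };
    let down = CrossAttnDownBlock2DConfig {
        transformer_layers_per_block: 2,
        ..Default::default()
    };
    let up = CrossAttnUpBlock2DConfig {
        transformer_layers_per_block: 2,
        ..Default::default()
    };
    (mid, down, up)
}

Since every default stays at 1, existing callers that never set the field keep the previous single-transformer-layer behaviour.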

@@ -366,6 +366,7 @@ pub struct UNetMidBlock2DCrossAttnConfig {
     pub cross_attn_dim: usize,
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for UNetMidBlock2DCrossAttnConfig {
@@ -379,6 +380,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
             cross_attn_dim: 1280,
             sliced_attention_size: None, // Sliced attention disabled
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -414,7 +416,7 @@ impl UNetMidBlock2DCrossAttn {
         let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
         let n_heads = config.attn_num_head_channels;
         let attn_cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             num_groups: resnet_groups,
             context_dim: Some(config.cross_attn_dim),
             sliced_attention_size: config.sliced_attention_size,
@@ -565,6 +567,7 @@ pub struct CrossAttnDownBlock2DConfig {
     // attention_type: "default"
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for CrossAttnDownBlock2DConfig {
@@ -575,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
             cross_attention_dim: 1280,
             sliced_attention_size: None,
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -605,7 +609,7 @@ impl CrossAttnDownBlock2D {
         )?;
         let n_heads = config.attn_num_head_channels;
         let cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             context_dim: Some(config.cross_attention_dim),
             num_groups: config.downblock.resnet_groups,
             sliced_attention_size: config.sliced_attention_size,
@@ -767,6 +771,7 @@ pub struct CrossAttnUpBlock2DConfig {
     // attention_type: "default"
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for CrossAttnUpBlock2DConfig {
@@ -777,6 +782,7 @@ impl Default for CrossAttnUpBlock2DConfig {
             cross_attention_dim: 1280,
             sliced_attention_size: None,
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -809,7 +815,7 @@ impl CrossAttnUpBlock2D {
         )?;
         let n_heads = config.attn_num_head_channels;
         let cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             context_dim: Some(config.cross_attention_dim),
             num_groups: config.upblock.resnet_groups,
             sliced_attention_size: config.sliced_attention_size,
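
The second bullet of the commit message, making the context contiguous if required, does not appear in the hunks above; it belongs to one of the other changed files. A rough sketch of that pattern with candle tensors, where the helper name and its exact placement are assumptions rather than the commit's actual code:

use candle::{Result, Tensor};

// Hedged sketch only: copy the cross-attention context into a contiguous
// layout when its current layout is not contiguous, and reuse it as-is
// otherwise.
fn contiguous_context(context: &Tensor) -> Result<Tensor> {
    if context.is_contiguous() {
        Ok(context.clone())
    } else {
        context.contiguous()
    }
}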