mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Use multiple transformer layers in the same cross-attn blocks. (#653)
* Use multiple transformer layers in the same cross-attn blocks.
* Make the context contiguous if required.
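For context, a hedged sketch of how the new field is meant to be consumed (the struct, field names, and default values come from the diff below; the import path and the chosen value of 2 are assumptions, not part of this commit): each cross-attn block config gains a transformer_layers_per_block field, and the block constructors forward it as the depth of their SpatialTransformerConfig, so a value above 1 stacks several transformer layers inside a single cross-attention block.

// Hedged usage sketch; the import path is an assumption and may differ across candle versions.
use candle_transformers::models::stable_diffusion::unet_2d_blocks::UNetMidBlock2DCrossAttnConfig;

fn main() {
    // Start from the Default impl shown in the diff and raise the new field so the
    // mid block stacks two transformer layers per cross-attn block instead of one.
    let mid_cfg = UNetMidBlock2DCrossAttnConfig {
        transformer_layers_per_block: 2,
        ..Default::default()
    };
    // The other fields keep the defaults from the diff.
    assert_eq!(mid_cfg.cross_attn_dim, 1280);
    assert!(!mid_cfg.use_linear_projection);
}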
@@ -366,6 +366,7 @@ pub struct UNetMidBlock2DCrossAttnConfig {
     pub cross_attn_dim: usize,
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }

 impl Default for UNetMidBlock2DCrossAttnConfig {
@@ -379,6 +380,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
             cross_attn_dim: 1280,
             sliced_attention_size: None, // Sliced attention disabled
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -414,7 +416,7 @@ impl UNetMidBlock2DCrossAttn {
         let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
         let n_heads = config.attn_num_head_channels;
         let attn_cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             num_groups: resnet_groups,
             context_dim: Some(config.cross_attn_dim),
             sliced_attention_size: config.sliced_attention_size,
@@ -565,6 +567,7 @@ pub struct CrossAttnDownBlock2DConfig {
     // attention_type: "default"
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }

 impl Default for CrossAttnDownBlock2DConfig {
@@ -575,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
             cross_attention_dim: 1280,
             sliced_attention_size: None,
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -605,7 +609,7 @@ impl CrossAttnDownBlock2D {
         )?;
         let n_heads = config.attn_num_head_channels;
         let cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             context_dim: Some(config.cross_attention_dim),
             num_groups: config.downblock.resnet_groups,
             sliced_attention_size: config.sliced_attention_size,
@@ -767,6 +771,7 @@ pub struct CrossAttnUpBlock2DConfig {
     // attention_type: "default"
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }

 impl Default for CrossAttnUpBlock2DConfig {
@@ -777,6 +782,7 @@ impl Default for CrossAttnUpBlock2DConfig {
             cross_attention_dim: 1280,
             sliced_attention_size: None,
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -809,7 +815,7 @@ impl CrossAttnUpBlock2D {
         )?;
         let n_heads = config.attn_num_head_channels;
         let cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             context_dim: Some(config.cross_attention_dim),
             num_groups: config.upblock.resnet_groups,
             sliced_attention_size: config.sliced_attention_size,
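The second item in the commit message, making the context contiguous if required, is not covered by the hunks above. A minimal sketch of that pattern with candle's Tensor API, assuming the cross-attention context arrives as a possibly strided tensor (the helper name is hypothetical and not the commit's actual code):

use candle_core::{Result, Tensor};

// Hedged sketch: attention kernels generally expect a contiguous memory layout,
// so the context is materialized into a contiguous buffer only when its current
// layout is not already contiguous.
fn context_contiguous(context: &Tensor) -> Result<Tensor> {
    if context.is_contiguous() {
        Ok(context.clone())
    } else {
        context.contiguous()
    }
}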