mirror of https://github.com/huggingface/candle.git
Use multiple transformer layers in the same cross-attn blocks. (#653)
* Use multiple transformer layers in the same cross-attn blocks.
* Make the context contiguous if required.
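The second bullet touches code that is not part of the hunks shown below. As a hedged, standalone illustration of what "make the context contiguous if required" means in candle's public tensor API (this is an assumption about the pattern, not the actual change): a transposed tensor is a strided view over the same storage, and .contiguous() materializes a row-major copy only when one is needed.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Build a small tensor, then transpose it: the result is a strided
    // view and is no longer contiguous in memory.
    let context = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    let transposed = context.t()?;
    assert!(!transposed.is_contiguous());
    // Copy into standard row-major layout only if required.
    let context = if transposed.is_contiguous() {
        transposed
    } else {
        transposed.contiguous()?
    };
    assert!(context.is_contiguous());
    Ok(())
}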
@@ -12,7 +12,9 @@ use candle_nn::Module;
 #[derive(Debug, Clone, Copy)]
 pub struct BlockConfig {
     pub out_channels: usize,
-    pub use_cross_attn: bool,
+    /// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the
+    /// number of transformer blocks to be used.
+    pub use_cross_attn: Option<usize>,
     pub attention_head_dim: usize,
 }
 
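A minimal standalone sketch of the new field's semantics (the struct is copied from the hunk above; the example values are illustrative, not taken from the diff): Some(d) enables cross-attention with d stacked transformer blocks, None disables cross-attention entirely.

#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
    pub out_channels: usize,
    /// `None`: no cross-attn; `Some(d)`: cross-attn with `d` transformer blocks.
    pub use_cross_attn: Option<usize>,
    pub attention_head_dim: usize,
}

fn main() {
    // Cross-attention with two transformer blocks (hypothetical values,
    // in the style of deeper SDXL-like configurations).
    let with_attn = BlockConfig {
        out_channels: 640,
        use_cross_attn: Some(2),
        attention_head_dim: 8,
    };
    // No cross-attention at all, as in the old `use_cross_attn: false`.
    let without_attn = BlockConfig {
        out_channels: 1280,
        use_cross_attn: None,
        attention_head_dim: 8,
    };
    println!("{with_attn:?}\n{without_attn:?}");
}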
@@ -41,22 +43,22 @@ impl Default for UNet2DConditionModelConfig {
             blocks: vec![
                 BlockConfig {
                     out_channels: 320,
-                    use_cross_attn: true,
+                    use_cross_attn: Some(1),
                     attention_head_dim: 8,
                 },
                 BlockConfig {
                     out_channels: 640,
-                    use_cross_attn: true,
+                    use_cross_attn: Some(1),
                     attention_head_dim: 8,
                 },
                 BlockConfig {
                     out_channels: 1280,
-                    use_cross_attn: true,
+                    use_cross_attn: Some(1),
                     attention_head_dim: 8,
                 },
                 BlockConfig {
                     out_channels: 1280,
-                    use_cross_attn: false,
+                    use_cross_attn: None,
                     attention_head_dim: 8,
                 },
             ],
@@ -149,13 +151,14 @@ impl UNet2DConditionModel {
                 downsample_padding: config.downsample_padding,
                 ..Default::default()
             };
-            if use_cross_attn {
+            if let Some(transformer_layers_per_block) = use_cross_attn {
                 let config = CrossAttnDownBlock2DConfig {
                     downblock: db_cfg,
                     attn_num_head_channels: attention_head_dim,
                     cross_attention_dim: config.cross_attention_dim,
                     sliced_attention_size,
                     use_linear_projection: config.use_linear_projection,
+                    transformer_layers_per_block,
                 };
                 let block = CrossAttnDownBlock2D::new(
                     vs_db.pp(&i.to_string()),
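The rewritten guard does two jobs at once: it keeps the old boolean gate (build a cross-attn block or a plain one) and binds the transformer depth for the block config. A hedged sketch of just this control-flow pattern, detached from the model code:

fn block_kind(use_cross_attn: Option<usize>) -> String {
    // `if let` replaces the old `if use_cross_attn { ... }` bool test and
    // simultaneously extracts the number of transformer layers.
    if let Some(transformer_layers_per_block) = use_cross_attn {
        format!("cross-attn block, {transformer_layers_per_block} transformer layer(s)")
    } else {
        "plain block, no cross-attn".to_string()
    }
}

fn main() {
    assert_eq!(block_kind(Some(2)), "cross-attn block, 2 transformer layer(s)");
    assert_eq!(block_kind(None), "plain block, no cross-attn");
}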
@@ -179,6 +182,11 @@ impl UNet2DConditionModel {
             })
             .collect::<Result<Vec<_>>>()?;
 
+        // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
+        let mid_transformer_layers_per_block = match config.blocks.last() {
+            None => 1,
+            Some(block) => block.use_cross_attn.unwrap_or(1),
+        };
        let mid_cfg = UNetMidBlock2DCrossAttnConfig {
             resnet_eps: config.norm_eps,
             output_scale_factor: config.mid_block_scale_factor,
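The mid block mirrors the diffusers rule linked above: it reuses the transformer depth of the last down block, falling back to 1 when there is no last block or the last block has no cross-attention. A small self-contained sketch of that rule, operating on bare Option<usize> values instead of BlockConfig:

fn mid_transformer_layers(blocks: &[Option<usize>]) -> usize {
    match blocks.last() {
        None => 1,                       // no blocks configured
        Some(last) => last.unwrap_or(1), // no cross-attn in last block => 1
    }
}

fn main() {
    // Last block has cross-attn with depth 10 => mid block gets depth 10.
    assert_eq!(mid_transformer_layers(&[Some(2), Some(10)]), 10);
    // Last block has no cross-attn => default depth 1.
    assert_eq!(mid_transformer_layers(&[Some(2), None]), 1);
    assert_eq!(mid_transformer_layers(&[]), 1);
}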
@@ -186,8 +194,10 @@ impl UNet2DConditionModel {
             attn_num_head_channels: bl_attention_head_dim,
             resnet_groups: Some(config.norm_num_groups),
             use_linear_projection: config.use_linear_projection,
+            transformer_layers_per_block: mid_transformer_layers_per_block,
             ..Default::default()
         };
+
         let mid_block = UNetMidBlock2DCrossAttn::new(
             vs.pp("mid_block"),
             bl_channels,
@@ -231,13 +241,14 @@ impl UNet2DConditionModel {
                 add_upsample: i < n_blocks - 1,
                 ..Default::default()
             };
-            if use_cross_attn {
+            if let Some(transformer_layers_per_block) = use_cross_attn {
                 let config = CrossAttnUpBlock2DConfig {
                     upblock: ub_cfg,
                     attn_num_head_channels: attention_head_dim,
                     cross_attention_dim: config.cross_attention_dim,
                     sliced_attention_size,
                     use_linear_projection: config.use_linear_projection,
+                    transformer_layers_per_block,
                 };
                 let block = CrossAttnUpBlock2D::new(
                     vs_ub.pp(&i.to_string()),