mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use multiple transformer layers in the same cross-attn blocks. (#653)
* Use multiple transformer layers in the same cross-attn blocks. * Make the context contiguous if required.
This commit is contained in:
@ -208,9 +208,9 @@ impl CrossAttention {
|
|||||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
let query = self.to_q.forward(xs)?;
|
let query = self.to_q.forward(xs)?;
|
||||||
let context = context.unwrap_or(xs);
|
let context = context.unwrap_or(xs).contiguous()?;
|
||||||
let key = self.to_k.forward(context)?;
|
let key = self.to_k.forward(&context)?;
|
||||||
let value = self.to_v.forward(context)?;
|
let value = self.to_v.forward(&context)?;
|
||||||
let query = self.reshape_heads_to_batch_dim(&query)?;
|
let query = self.reshape_heads_to_batch_dim(&query)?;
|
||||||
let key = self.reshape_heads_to_batch_dim(&key)?;
|
let key = self.reshape_heads_to_batch_dim(&key)?;
|
||||||
let value = self.reshape_heads_to_batch_dim(&value)?;
|
let value = self.reshape_heads_to_batch_dim(&value)?;
|
||||||
|
@ -28,10 +28,10 @@ impl StableDiffusionConfig {
|
|||||||
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
|
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
|
||||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||||
blocks: vec![
|
blocks: vec![
|
||||||
bc(320, true, 8),
|
bc(320, Some(1), 8),
|
||||||
bc(640, true, 8),
|
bc(640, Some(1), 8),
|
||||||
bc(1280, true, 8),
|
bc(1280, Some(1), 8),
|
||||||
bc(1280, false, 8),
|
bc(1280, None, 8),
|
||||||
],
|
],
|
||||||
center_input_sample: false,
|
center_input_sample: false,
|
||||||
cross_attention_dim: 768,
|
cross_attention_dim: 768,
|
||||||
@ -90,10 +90,10 @@ impl StableDiffusionConfig {
|
|||||||
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
|
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
|
||||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||||
blocks: vec![
|
blocks: vec![
|
||||||
bc(320, true, 5),
|
bc(320, Some(1), 5),
|
||||||
bc(640, true, 10),
|
bc(640, Some(1), 10),
|
||||||
bc(1280, true, 20),
|
bc(1280, Some(1), 20),
|
||||||
bc(1280, false, 20),
|
bc(1280, None, 20),
|
||||||
],
|
],
|
||||||
center_input_sample: false,
|
center_input_sample: false,
|
||||||
cross_attention_dim: 1024,
|
cross_attention_dim: 1024,
|
||||||
@ -171,7 +171,11 @@ impl StableDiffusionConfig {
|
|||||||
};
|
};
|
||||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
|
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
|
||||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||||
blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)],
|
blocks: vec![
|
||||||
|
bc(320, None, 5),
|
||||||
|
bc(640, Some(2), 10),
|
||||||
|
bc(1280, Some(10), 20),
|
||||||
|
],
|
||||||
center_input_sample: false,
|
center_input_sample: false,
|
||||||
cross_attention_dim: 2048,
|
cross_attention_dim: 2048,
|
||||||
downsample_padding: 1,
|
downsample_padding: 1,
|
||||||
|
@ -12,7 +12,9 @@ use candle_nn::Module;
|
|||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct BlockConfig {
|
pub struct BlockConfig {
|
||||||
pub out_channels: usize,
|
pub out_channels: usize,
|
||||||
pub use_cross_attn: bool,
|
/// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the
|
||||||
|
/// number of transformer blocks to be used.
|
||||||
|
pub use_cross_attn: Option<usize>,
|
||||||
pub attention_head_dim: usize,
|
pub attention_head_dim: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,22 +43,22 @@ impl Default for UNet2DConditionModelConfig {
|
|||||||
blocks: vec![
|
blocks: vec![
|
||||||
BlockConfig {
|
BlockConfig {
|
||||||
out_channels: 320,
|
out_channels: 320,
|
||||||
use_cross_attn: true,
|
use_cross_attn: Some(1),
|
||||||
attention_head_dim: 8,
|
attention_head_dim: 8,
|
||||||
},
|
},
|
||||||
BlockConfig {
|
BlockConfig {
|
||||||
out_channels: 640,
|
out_channels: 640,
|
||||||
use_cross_attn: true,
|
use_cross_attn: Some(1),
|
||||||
attention_head_dim: 8,
|
attention_head_dim: 8,
|
||||||
},
|
},
|
||||||
BlockConfig {
|
BlockConfig {
|
||||||
out_channels: 1280,
|
out_channels: 1280,
|
||||||
use_cross_attn: true,
|
use_cross_attn: Some(1),
|
||||||
attention_head_dim: 8,
|
attention_head_dim: 8,
|
||||||
},
|
},
|
||||||
BlockConfig {
|
BlockConfig {
|
||||||
out_channels: 1280,
|
out_channels: 1280,
|
||||||
use_cross_attn: false,
|
use_cross_attn: None,
|
||||||
attention_head_dim: 8,
|
attention_head_dim: 8,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -149,13 +151,14 @@ impl UNet2DConditionModel {
|
|||||||
downsample_padding: config.downsample_padding,
|
downsample_padding: config.downsample_padding,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
if use_cross_attn {
|
if let Some(transformer_layers_per_block) = use_cross_attn {
|
||||||
let config = CrossAttnDownBlock2DConfig {
|
let config = CrossAttnDownBlock2DConfig {
|
||||||
downblock: db_cfg,
|
downblock: db_cfg,
|
||||||
attn_num_head_channels: attention_head_dim,
|
attn_num_head_channels: attention_head_dim,
|
||||||
cross_attention_dim: config.cross_attention_dim,
|
cross_attention_dim: config.cross_attention_dim,
|
||||||
sliced_attention_size,
|
sliced_attention_size,
|
||||||
use_linear_projection: config.use_linear_projection,
|
use_linear_projection: config.use_linear_projection,
|
||||||
|
transformer_layers_per_block,
|
||||||
};
|
};
|
||||||
let block = CrossAttnDownBlock2D::new(
|
let block = CrossAttnDownBlock2D::new(
|
||||||
vs_db.pp(&i.to_string()),
|
vs_db.pp(&i.to_string()),
|
||||||
@ -179,6 +182,11 @@ impl UNet2DConditionModel {
|
|||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
// https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
|
||||||
|
let mid_transformer_layers_per_block = match config.blocks.last() {
|
||||||
|
None => 1,
|
||||||
|
Some(block) => block.use_cross_attn.unwrap_or(1),
|
||||||
|
};
|
||||||
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
|
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
|
||||||
resnet_eps: config.norm_eps,
|
resnet_eps: config.norm_eps,
|
||||||
output_scale_factor: config.mid_block_scale_factor,
|
output_scale_factor: config.mid_block_scale_factor,
|
||||||
@ -186,8 +194,10 @@ impl UNet2DConditionModel {
|
|||||||
attn_num_head_channels: bl_attention_head_dim,
|
attn_num_head_channels: bl_attention_head_dim,
|
||||||
resnet_groups: Some(config.norm_num_groups),
|
resnet_groups: Some(config.norm_num_groups),
|
||||||
use_linear_projection: config.use_linear_projection,
|
use_linear_projection: config.use_linear_projection,
|
||||||
|
transformer_layers_per_block: mid_transformer_layers_per_block,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mid_block = UNetMidBlock2DCrossAttn::new(
|
let mid_block = UNetMidBlock2DCrossAttn::new(
|
||||||
vs.pp("mid_block"),
|
vs.pp("mid_block"),
|
||||||
bl_channels,
|
bl_channels,
|
||||||
@ -231,13 +241,14 @@ impl UNet2DConditionModel {
|
|||||||
add_upsample: i < n_blocks - 1,
|
add_upsample: i < n_blocks - 1,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
if use_cross_attn {
|
if let Some(transformer_layers_per_block) = use_cross_attn {
|
||||||
let config = CrossAttnUpBlock2DConfig {
|
let config = CrossAttnUpBlock2DConfig {
|
||||||
upblock: ub_cfg,
|
upblock: ub_cfg,
|
||||||
attn_num_head_channels: attention_head_dim,
|
attn_num_head_channels: attention_head_dim,
|
||||||
cross_attention_dim: config.cross_attention_dim,
|
cross_attention_dim: config.cross_attention_dim,
|
||||||
sliced_attention_size,
|
sliced_attention_size,
|
||||||
use_linear_projection: config.use_linear_projection,
|
use_linear_projection: config.use_linear_projection,
|
||||||
|
transformer_layers_per_block,
|
||||||
};
|
};
|
||||||
let block = CrossAttnUpBlock2D::new(
|
let block = CrossAttnUpBlock2D::new(
|
||||||
vs_ub.pp(&i.to_string()),
|
vs_ub.pp(&i.to_string()),
|
||||||
|
@ -366,6 +366,7 @@ pub struct UNetMidBlock2DCrossAttnConfig {
|
|||||||
pub cross_attn_dim: usize,
|
pub cross_attn_dim: usize,
|
||||||
pub sliced_attention_size: Option<usize>,
|
pub sliced_attention_size: Option<usize>,
|
||||||
pub use_linear_projection: bool,
|
pub use_linear_projection: bool,
|
||||||
|
pub transformer_layers_per_block: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for UNetMidBlock2DCrossAttnConfig {
|
impl Default for UNetMidBlock2DCrossAttnConfig {
|
||||||
@ -379,6 +380,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
|
|||||||
cross_attn_dim: 1280,
|
cross_attn_dim: 1280,
|
||||||
sliced_attention_size: None, // Sliced attention disabled
|
sliced_attention_size: None, // Sliced attention disabled
|
||||||
use_linear_projection: false,
|
use_linear_projection: false,
|
||||||
|
transformer_layers_per_block: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -414,7 +416,7 @@ impl UNetMidBlock2DCrossAttn {
|
|||||||
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
|
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
|
||||||
let n_heads = config.attn_num_head_channels;
|
let n_heads = config.attn_num_head_channels;
|
||||||
let attn_cfg = SpatialTransformerConfig {
|
let attn_cfg = SpatialTransformerConfig {
|
||||||
depth: 1,
|
depth: config.transformer_layers_per_block,
|
||||||
num_groups: resnet_groups,
|
num_groups: resnet_groups,
|
||||||
context_dim: Some(config.cross_attn_dim),
|
context_dim: Some(config.cross_attn_dim),
|
||||||
sliced_attention_size: config.sliced_attention_size,
|
sliced_attention_size: config.sliced_attention_size,
|
||||||
@ -565,6 +567,7 @@ pub struct CrossAttnDownBlock2DConfig {
|
|||||||
// attention_type: "default"
|
// attention_type: "default"
|
||||||
pub sliced_attention_size: Option<usize>,
|
pub sliced_attention_size: Option<usize>,
|
||||||
pub use_linear_projection: bool,
|
pub use_linear_projection: bool,
|
||||||
|
pub transformer_layers_per_block: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for CrossAttnDownBlock2DConfig {
|
impl Default for CrossAttnDownBlock2DConfig {
|
||||||
@ -575,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
|
|||||||
cross_attention_dim: 1280,
|
cross_attention_dim: 1280,
|
||||||
sliced_attention_size: None,
|
sliced_attention_size: None,
|
||||||
use_linear_projection: false,
|
use_linear_projection: false,
|
||||||
|
transformer_layers_per_block: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -605,7 +609,7 @@ impl CrossAttnDownBlock2D {
|
|||||||
)?;
|
)?;
|
||||||
let n_heads = config.attn_num_head_channels;
|
let n_heads = config.attn_num_head_channels;
|
||||||
let cfg = SpatialTransformerConfig {
|
let cfg = SpatialTransformerConfig {
|
||||||
depth: 1,
|
depth: config.transformer_layers_per_block,
|
||||||
context_dim: Some(config.cross_attention_dim),
|
context_dim: Some(config.cross_attention_dim),
|
||||||
num_groups: config.downblock.resnet_groups,
|
num_groups: config.downblock.resnet_groups,
|
||||||
sliced_attention_size: config.sliced_attention_size,
|
sliced_attention_size: config.sliced_attention_size,
|
||||||
@ -767,6 +771,7 @@ pub struct CrossAttnUpBlock2DConfig {
|
|||||||
// attention_type: "default"
|
// attention_type: "default"
|
||||||
pub sliced_attention_size: Option<usize>,
|
pub sliced_attention_size: Option<usize>,
|
||||||
pub use_linear_projection: bool,
|
pub use_linear_projection: bool,
|
||||||
|
pub transformer_layers_per_block: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for CrossAttnUpBlock2DConfig {
|
impl Default for CrossAttnUpBlock2DConfig {
|
||||||
@ -777,6 +782,7 @@ impl Default for CrossAttnUpBlock2DConfig {
|
|||||||
cross_attention_dim: 1280,
|
cross_attention_dim: 1280,
|
||||||
sliced_attention_size: None,
|
sliced_attention_size: None,
|
||||||
use_linear_projection: false,
|
use_linear_projection: false,
|
||||||
|
transformer_layers_per_block: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -809,7 +815,7 @@ impl CrossAttnUpBlock2D {
|
|||||||
)?;
|
)?;
|
||||||
let n_heads = config.attn_num_head_channels;
|
let n_heads = config.attn_num_head_channels;
|
||||||
let cfg = SpatialTransformerConfig {
|
let cfg = SpatialTransformerConfig {
|
||||||
depth: 1,
|
depth: config.transformer_layers_per_block,
|
||||||
context_dim: Some(config.cross_attention_dim),
|
context_dim: Some(config.cross_attention_dim),
|
||||||
num_groups: config.upblock.resnet_groups,
|
num_groups: config.upblock.resnet_groups,
|
||||||
sliced_attention_size: config.sliced_attention_size,
|
sliced_attention_size: config.sliced_attention_size,
|
||||||
|
Reference in New Issue
Block a user