Flash-attention support in stable diffusion (#487)

* Add flash-attention for the stable-diffusion example.

* Change the dtype.

* Silly fix.

* Another fix.

* Revert the dtype back to the query dtype after apply flash-attn.
This commit is contained in:
Laurent Mazare
2023-08-17 12:16:40 +01:00
committed by GitHub
parent 03be33eea4
commit c3176f0dfb
5 changed files with 78 additions and 32 deletions

View File

@ -1,4 +1,3 @@
#![allow(dead_code)]
//! 2D UNet Denoising Models
//!
//! The 2D Unet models take as input a noisy sample and the current diffusion
@ -103,6 +102,7 @@ impl UNet2DConditionModel {
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
use_flash_attn: bool,
config: UNet2DConditionModelConfig,
) -> Result<Self> {
let n_blocks = config.blocks.len();
@ -161,6 +161,7 @@ impl UNet2DConditionModel {
in_channels,
out_channels,
Some(time_embed_dim),
use_flash_attn,
config,
)?;
Ok(UNetDownBlock::CrossAttn(block))
@ -190,6 +191,7 @@ impl UNet2DConditionModel {
vs.pp("mid_block"),
bl_channels,
Some(time_embed_dim),
use_flash_attn,
mid_cfg,
)?;
@ -242,6 +244,7 @@ impl UNet2DConditionModel {
prev_out_channels,
out_channels,
Some(time_embed_dim),
use_flash_attn,
config,
)?;
Ok(UNetUpBlock::CrossAttn(block))