Flash-attention support in stable diffusion (#487)

* Add flash-attention for the stable-diffusion example.

* Change the dtype.

* Silly fix.

* Another fix.

* Revert the dtype back to the query dtype after applying flash-attn.
This commit is contained in:
Laurent Mazare
2023-08-17 12:16:40 +01:00
committed by GitHub
parent 03be33eea4
commit c3176f0dfb
5 changed files with 78 additions and 32 deletions

View File

@ -1,4 +1,3 @@
#![allow(dead_code)]
//! 2D UNet Building Blocks
//!
use crate::attention::{
@ -393,6 +392,7 @@ impl UNetMidBlock2DCrossAttn {
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: UNetMidBlock2DCrossAttnConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
@ -423,6 +423,7 @@ impl UNetMidBlock2DCrossAttn {
in_channels,
n_heads,
in_channels / n_heads,
use_flash_attn,
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
@ -588,6 +589,7 @@ impl CrossAttnDownBlock2D {
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: CrossAttnDownBlock2DConfig,
) -> Result<Self> {
let downblock = DownBlock2D::new(
@ -613,6 +615,7 @@ impl CrossAttnDownBlock2D {
out_channels,
n_heads,
out_channels / n_heads,
use_flash_attn,
cfg,
)
})
@ -789,6 +792,7 @@ impl CrossAttnUpBlock2D {
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: CrossAttnUpBlock2DConfig,
) -> Result<Self> {
let upblock = UpBlock2D::new(
@ -815,6 +819,7 @@ impl CrossAttnUpBlock2D {
out_channels,
n_heads,
out_channels / n_heads,
use_flash_attn,
cfg,
)
})