mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Flash-attention support in stable diffusion (#487)
* Add flash-attention for the stable-diffusion example. * Change the dtype. * Silly fix. * Another fix. * Revert the dtype back to the query dtype after apply flash-attn.
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
#![allow(dead_code)]
|
||||
//! 2D UNet Building Blocks
|
||||
//!
|
||||
use crate::attention::{
|
||||
@ -393,6 +392,7 @@ impl UNetMidBlock2DCrossAttn {
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
use_flash_attn: bool,
|
||||
config: UNetMidBlock2DCrossAttnConfig,
|
||||
) -> Result<Self> {
|
||||
let vs_resnets = vs.pp("resnets");
|
||||
@ -423,6 +423,7 @@ impl UNetMidBlock2DCrossAttn {
|
||||
in_channels,
|
||||
n_heads,
|
||||
in_channels / n_heads,
|
||||
use_flash_attn,
|
||||
attn_cfg,
|
||||
)?;
|
||||
let resnet = ResnetBlock2D::new(
|
||||
@ -588,6 +589,7 @@ impl CrossAttnDownBlock2D {
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
use_flash_attn: bool,
|
||||
config: CrossAttnDownBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let downblock = DownBlock2D::new(
|
||||
@ -613,6 +615,7 @@ impl CrossAttnDownBlock2D {
|
||||
out_channels,
|
||||
n_heads,
|
||||
out_channels / n_heads,
|
||||
use_flash_attn,
|
||||
cfg,
|
||||
)
|
||||
})
|
||||
@ -789,6 +792,7 @@ impl CrossAttnUpBlock2D {
|
||||
prev_output_channels: usize,
|
||||
out_channels: usize,
|
||||
temb_channels: Option<usize>,
|
||||
use_flash_attn: bool,
|
||||
config: CrossAttnUpBlock2DConfig,
|
||||
) -> Result<Self> {
|
||||
let upblock = UpBlock2D::new(
|
||||
@ -815,6 +819,7 @@ impl CrossAttnUpBlock2D {
|
||||
out_channels,
|
||||
n_heads,
|
||||
out_channels / n_heads,
|
||||
use_flash_attn,
|
||||
cfg,
|
||||
)
|
||||
})
|
||||
|
Reference in New Issue
Block a user