mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Flash-attention support in stable diffusion (#487)
* Add flash-attention for the stable-diffusion example.
* Change the dtype.
* Silly fix.
* Another fix.
* Revert the dtype back to the query dtype after applying flash-attn.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
#![allow(dead_code)]
|
||||
//! 2D UNet Denoising Models
|
||||
//!
|
||||
//! The 2D Unet models take as input a noisy sample and the current diffusion
|
||||
@@ -103,6 +102,7 @@ impl UNet2DConditionModel {
|
||||
vs: nn::VarBuilder,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
use_flash_attn: bool,
|
||||
config: UNet2DConditionModelConfig,
|
||||
) -> Result<Self> {
|
||||
let n_blocks = config.blocks.len();
|
||||
@@ -161,6 +161,7 @@ impl UNet2DConditionModel {
|
||||
in_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
use_flash_attn,
|
||||
config,
|
||||
)?;
|
||||
Ok(UNetDownBlock::CrossAttn(block))
|
||||
@@ -190,6 +191,7 @@ impl UNet2DConditionModel {
|
||||
vs.pp("mid_block"),
|
||||
bl_channels,
|
||||
Some(time_embed_dim),
|
||||
use_flash_attn,
|
||||
mid_cfg,
|
||||
)?;
|
||||
|
||||
@@ -242,6 +244,7 @@ impl UNet2DConditionModel {
|
||||
prev_out_channels,
|
||||
out_channels,
|
||||
Some(time_embed_dim),
|
||||
use_flash_attn,
|
||||
config,
|
||||
)?;
|
||||
Ok(UNetUpBlock::CrossAttn(block))
|
||||
|
Reference in New Issue
Block a user