Flash-attention support in stable diffusion (#487)

* Add flash-attention for the stable-diffusion example.

* Change the dtype.

* Silly fix.

* Another fix.

* Revert the dtype back to the query dtype after applying flash-attn (see the sketch below).
Laurent Mazare, 2023-08-17 12:16:40 +01:00 (committed by GitHub)
parent 03be33eea4, commit c3176f0dfb
5 changed files with 78 additions and 32 deletions
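The dtype commits above reflect that the flash-attention kernel operates on f16 tensors: the query/key/value are cast down before the call and the output is cast back to the original query dtype. A minimal sketch of that round-trip, assuming candle's `candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)` entry point (the helper name below is hypothetical, not from this diff):

    use candle::{DType, Tensor};

    // Cast q/k/v to f16 for the flash-attention kernel, then revert the
    // output to the original query dtype so downstream layers are unaffected.
    // Inputs are assumed to be laid out as (batch, seq_len, num_heads,
    // head_dim), which is what candle-flash-attn expects.
    fn attention_with_flash(
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        scale: f64, // typically 1 / sqrt(head_dim)
    ) -> candle::Result<Tensor> {
        let init_dtype = query.dtype();
        let q = query.to_dtype(DType::F16)?;
        let k = key.to_dtype(DType::F16)?;
        let v = value.to_dtype(DType::F16)?;
        let out = candle_flash_attn::flash_attn(&q, &k, &v, scale as f32, false)?;
        // Revert to the query dtype rather than leaving the result in f16.
        out.to_dtype(init_dtype)
    }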


@@ -90,6 +90,9 @@ struct Args {
     /// Generate intermediary images at each step.
     #[arg(long, action)]
     intermediary_images: bool,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
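With clap's default kebab-case flag naming, the new `use_flash_attn` field surfaces as `--use-flash-attn`. A hypothetical invocation (cargo feature names and prompt assumed; flash-attention requires a CUDA build):

    cargo run --example stable-diffusion --release --features cuda,flash-attn -- \
        --prompt "a rusty robot holding a candle" --use-flash-attn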
@@ -268,7 +271,7 @@ fn run(args: Args) -> Result<()> {
     let vae = sd_config.build_vae(&vae_weights, &device)?;
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
+    let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;
 
     let bsize = 1;
     for idx in 0..num_samples {
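The call-site change above threads the flag through `build_unet` down to the UNet's attention blocks, where it selects between the flash kernel and the standard softmax path. A minimal sketch of that selection, reusing `attention_with_flash` from the earlier sketch (assumed structure, not the example's verbatim code):

    use candle::{D, Tensor};

    fn attention(
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        scale: f64,
        use_flash_attn: bool,
    ) -> candle::Result<Tensor> {
        if use_flash_attn {
            // The f16 cast and dtype revert happen inside the helper.
            attention_with_flash(q, k, v, scale)
        } else {
            // Standard attention: softmax(q · kᵀ · scale) · v.
            // Layout reshapes between the two paths are elided here.
            let attn = (q.matmul(&k.t()?)? * scale)?;
            let attn = candle_nn::ops::softmax(&attn, D::Minus1)?;
            attn.matmul(v)
        }
    }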