Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 10:26:33 +00:00.
Flash-attention support in stable diffusion (#487)
* Add flash-attention for the stable-diffusion example.
* Change the dtype.
* Silly fix.
* Another fix.
* Revert the dtype back to the query dtype after applying flash-attn.
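The last point in the message reflects that the flash-attention kernel only operates on half-precision tensors: the query/key/value are cast down before the kernel runs and the result is cast back afterwards. A minimal sketch of that dtype round-trip using candle's flash_attn helper (the `attn` wrapper and the hard-coded F16 cast are illustrative assumptions, not the literal code from this commit):

use candle::{DType, Result, Tensor};

// Hypothetical wrapper: run flash-attention in f16, then cast the
// result back to the original dtype of the query tensor.
fn attn(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    let init_dtype = q.dtype();
    // The flash-attn kernel expects half-precision inputs.
    let q = q.to_dtype(DType::F16)?;
    let k = k.to_dtype(DType::F16)?;
    let v = v.to_dtype(DType::F16)?;
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, false)?;
    // Revert the output to the query dtype so downstream layers see no change.
    out.to_dtype(init_dtype)
}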
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -90,6 +90,9 @@ struct Args {
     /// Generate intermediary images at each step.
     #[arg(long, action)]
     intermediary_images: bool,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -268,7 +271,7 @@ fn run(args: Args) -> Result<()> {
     let vae = sd_config.build_vae(&vae_weights, &device)?;
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
+    let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;
 
     let bsize = 1;
     for idx in 0..num_samples {
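With the flag threaded through to build_unet, the example can opt into the flash-attention kernels from the command line; clap derives the `--use-flash-attn` flag name from the `use_flash_attn` field. A hedged usage sketch (assuming the example is built with candle's cuda and flash-attn cargo features; the prompt text is illustrative):

cargo run --example stable-diffusion --release --features cuda,flash-attn -- \
    --prompt "A rusty robot holding a fire torch" \
    --use-flash-attn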