Flash-attention support in stable diffusion (#487)

* Add flash-attention for the stable-diffusion example.

* Change the dtype.

* Silly fix.

* Another fix.

* Revert the dtype back to the query dtype after applying flash-attn.
Author: Laurent Mazare
Date: 2023-08-17 12:16:40 +01:00
Committed by: GitHub
Parent commit: 03be33eea4
Commit: c3176f0dfb

5 changed files with 78 additions and 32 deletions
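
The last message bullet refers to the dtype round-trip that flash-attention requires: the kernel runs in half precision, so the inputs are cast to f16 and the result is cast back to the original query dtype before it flows into the rest of the network. A minimal sketch of that pattern, assuming the `candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)` entry point; the helper name and wiring here are illustrative, not the exact code from this commit:

```rust
use candle::{DType, Result, Tensor};

/// Illustrative helper: run flash-attention on (typically f32) tensors by
/// casting to f16 for the kernel and reverting to the query dtype afterwards.
fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // Remember the incoming dtype so the output matches it.
    let init_dtype = q.dtype();
    // The flash-attention kernel operates on half precision.
    let q = q.to_dtype(DType::F16)?;
    let k = k.to_dtype(DType::F16)?;
    let v = v.to_dtype(DType::F16)?;
    // `causal` is false: attention in the stable-diffusion UNet is unmasked.
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, false)?;
    // Revert the output to the query dtype so downstream layers are unaffected.
    out.to_dtype(init_dtype)
}
```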

@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 use crate::schedulers::PredictionType;
 use crate::{clip, ddim, unet_2d, vae};
 use candle::{DType, Device, Result};
@@ -156,22 +155,6 @@ impl StableDiffusionConfig {
         )
     }

-    pub fn v2_1_inpaint(
-        sliced_attention_size: Option<usize>,
-        height: Option<usize>,
-        width: Option<usize>,
-    ) -> Self {
-        // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
-        // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
-        // type being "epsilon" by default and not "v_prediction".
-        Self::v2_1_(
-            sliced_attention_size,
-            height,
-            width,
-            PredictionType::Epsilon,
-        )
-    }
-
     pub fn build_vae<P: AsRef<std::path::Path>>(
         &self,
         vae_weights: P,
@@ -190,11 +173,18 @@ impl StableDiffusionConfig {
         unet_weights: P,
         device: &Device,
         in_channels: usize,
+        use_flash_attn: bool,
     ) -> Result<unet_2d::UNet2DConditionModel> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
         let weights = weights.deserialize()?;
         let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
-        let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
+        let unet = unet_2d::UNet2DConditionModel::new(
+            vs_unet,
+            in_channels,
+            4,
+            use_flash_attn,
+            self.unet.clone(),
+        )?;
         Ok(unet)
     }
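
With the extra flag threaded through `build_unet`, callers opt into flash-attention when constructing the UNet. A hypothetical call site under the new signature; the weights path and helper function are placeholders, not code from this commit:

```rust
use candle::{Device, Result};

// Hypothetical wrapper; `sd_config` would come from e.g.
// StableDiffusionConfig::v2_1(None, None, None).
fn load_unet(
    sd_config: &StableDiffusionConfig,
    device: &Device,
) -> Result<unet_2d::UNet2DConditionModel> {
    sd_config.build_unet(
        "unet.safetensors", // placeholder path to the UNet safetensors weights
        device,
        4,    // in_channels for the standard text-to-image UNet
        true, // use_flash_attn: enable the flash-attention kernels
    )
}
```

Surfacing this as an opt-in flag rather than a default keeps the example usable on setups without the CUDA flash-attention kernels.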