Flash-attention support in stable diffusion (#487)

* Add flash-attention for the stable-diffusion example.

* Change the dtype.

* Silly fix.

* Another fix.

* Revert the dtype back to the query dtype after applying flash-attn.
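The last bullet refers to the dtype round-trip in the attention module below: the flash-attention kernel runs on half-precision tensors (f16 here) laid out as (batch, seq_len, num_heads, head_dim), so the query/key/value are converted and reshaped before the call and the output is converted back to the query's original dtype. A condensed sketch of that round-trip, simplified from the diff below (the function name is illustrative, not part of the commit):

use candle::{DType, Result, Tensor};

#[cfg(feature = "flash-attn")]
fn attention_with_flash(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
    // q/k/v arrive as (batch * heads, seq_len, head_dim); the kernel is fed
    // (batch, seq_len, num_heads, head_dim) f16 tensors, hence the unsqueeze + transpose.
    let init_dtype = q.dtype();
    let q = q.to_dtype(DType::F16)?.unsqueeze(0)?.transpose(1, 2)?;
    let k = k.to_dtype(DType::F16)?.unsqueeze(0)?.transpose(1, 2)?;
    let v = v.to_dtype(DType::F16)?.unsqueeze(0)?.transpose(1, 2)?;
    let xs = candle_flash_attn::flash_attn(&q, &k, &v, scale as f32, false)?;
    // Undo the layout change and restore the original dtype of the query.
    xs.transpose(1, 2)?.squeeze(0)?.to_dtype(init_dtype)
}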
Author: Laurent Mazare
Date: 2023-08-17 12:16:40 +01:00
Committed by: GitHub
Parent: 03be33eea4
Commit: c3176f0dfb
5 changed files with 78 additions and 32 deletions

View File

@@ -1,4 +1,3 @@
#![allow(dead_code)]
//! Attention Based Building Blocks
use candle::{IndexOp, Result, Tensor, D};
use candle_nn as nn;
@@ -61,6 +60,22 @@ impl FeedForward {
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug)]
struct CrossAttention {
to_q: nn::Linear,
@@ -72,6 +87,7 @@ struct CrossAttention {
slice_size: Option<usize>,
span: tracing::Span,
span_attn: tracing::Span,
use_flash_attn: bool,
}
impl CrossAttention {
@@ -83,6 +99,7 @@ impl CrossAttention {
heads: usize,
dim_head: usize,
slice_size: Option<usize>,
use_flash_attn: bool,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let context_dim = context_dim.unwrap_or(query_dim);
@@ -103,6 +120,7 @@ impl CrossAttention {
slice_size,
span,
span_attn,
use_flash_attn,
})
}
@@ -146,8 +164,28 @@ impl CrossAttention {
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
let xs = if self.use_flash_attn {
let init_dtype = query.dtype();
let q = query
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let k = key
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let v = value
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
flash_attn(&q, &k, &v, self.scale as f32, false)?
.transpose(1, 2)?
.squeeze(0)?
.to_dtype(init_dtype)?
} else {
let xs = query.matmul(&(key.t()? * self.scale)?)?;
nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
};
self.reshape_batch_dim_to_heads(&xs)
}
@@ -160,15 +198,17 @@ impl CrossAttention {
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
let xs = match self.slice_size {
None => self.attention(&query, &key, &value)?,
Some(slice_size) => {
if query.dim(0)? / slice_size <= 1 {
self.attention(&query, &key, &value)?
} else {
self.sliced_attention(&query, &key, &value, slice_size)?
}
let dim0 = query.dim(0)?;
let slice_size = self.slice_size.and_then(|slice_size| {
if dim0 < slice_size {
None
} else {
Some(slice_size)
}
});
let xs = match slice_size {
None => self.attention(&query, &key, &value)?,
Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
};
self.to_out.forward(&xs)
}
@@ -194,6 +234,7 @@ impl BasicTransformerBlock {
d_head: usize,
context_dim: Option<usize>,
sliced_attention_size: Option<usize>,
use_flash_attn: bool,
) -> Result<Self> {
let attn1 = CrossAttention::new(
vs.pp("attn1"),
@@ -202,6 +243,7 @@ impl BasicTransformerBlock {
n_heads,
d_head,
sliced_attention_size,
use_flash_attn,
)?;
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
let attn2 = CrossAttention::new(
@@ -211,6 +253,7 @@ impl BasicTransformerBlock {
n_heads,
d_head,
sliced_attention_size,
use_flash_attn,
)?;
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
@@ -279,6 +322,7 @@ impl SpatialTransformer {
in_channels: usize,
n_heads: usize,
d_head: usize,
use_flash_attn: bool,
config: SpatialTransformerConfig,
) -> Result<Self> {
let inner_dim = n_heads * d_head;
@@ -304,6 +348,7 @@ impl SpatialTransformer {
d_head,
config.context_dim,
config.sliced_attention_size,
use_flash_attn,
)?;
transformer_blocks.push(tb)
}

View File

@@ -90,6 +90,9 @@ struct Args {
/// Generate intermediary images at each step.
#[arg(long, action)]
intermediary_images: bool,
#[arg(long)]
use_flash_attn: bool,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -268,7 +271,7 @@ fn run(args: Args) -> Result<()> {
let vae = sd_config.build_vae(&vae_weights, &device)?;
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;
let bsize = 1;
for idx in 0..num_samples {

View File

@@ -1,4 +1,3 @@
#![allow(dead_code)]
use crate::schedulers::PredictionType;
use crate::{clip, ddim, unet_2d, vae};
use candle::{DType, Device, Result};
@@ -156,22 +155,6 @@ impl StableDiffusionConfig {
)
}
pub fn v2_1_inpaint(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
// https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
// This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
// type being "epsilon" by default and not "v_prediction".
Self::v2_1_(
sliced_attention_size,
height,
width,
PredictionType::Epsilon,
)
}
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@@ -190,11 +173,18 @@ impl StableDiffusionConfig {
unet_weights: P,
device: &Device,
in_channels: usize,
use_flash_attn: bool,
) -> Result<unet_2d::UNet2DConditionModel> {
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
let weights = weights.deserialize()?;
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
let unet = unet_2d::UNet2DConditionModel::new(
vs_unet,
in_channels,
4,
use_flash_attn,
self.unet.clone(),
)?;
Ok(unet)
}

View File

@@ -1,4 +1,3 @@
#![allow(dead_code)]
//! 2D UNet Denoising Models
//!
//! The 2D Unet models take as input a noisy sample and the current diffusion
@@ -103,6 +102,7 @@ impl UNet2DConditionModel {
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
use_flash_attn: bool,
config: UNet2DConditionModelConfig,
) -> Result<Self> {
let n_blocks = config.blocks.len();
@@ -161,6 +161,7 @@ impl UNet2DConditionModel {
in_channels,
out_channels,
Some(time_embed_dim),
use_flash_attn,
config,
)?;
Ok(UNetDownBlock::CrossAttn(block))
@@ -190,6 +191,7 @@ impl UNet2DConditionModel {
vs.pp("mid_block"),
bl_channels,
Some(time_embed_dim),
use_flash_attn,
mid_cfg,
)?;
@@ -242,6 +244,7 @@ impl UNet2DConditionModel {
prev_out_channels,
out_channels,
Some(time_embed_dim),
use_flash_attn,
config,
)?;
Ok(UNetUpBlock::CrossAttn(block))

View File

@@ -1,4 +1,3 @@
#![allow(dead_code)]
//! 2D UNet Building Blocks
//!
use crate::attention::{
@@ -393,6 +392,7 @@ impl UNetMidBlock2DCrossAttn {
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: UNetMidBlock2DCrossAttnConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
@@ -423,6 +423,7 @@ impl UNetMidBlock2DCrossAttn {
in_channels,
n_heads,
in_channels / n_heads,
use_flash_attn,
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
@@ -588,6 +589,7 @@ impl CrossAttnDownBlock2D {
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: CrossAttnDownBlock2DConfig,
) -> Result<Self> {
let downblock = DownBlock2D::new(
@@ -613,6 +615,7 @@ impl CrossAttnDownBlock2D {
out_channels,
n_heads,
out_channels / n_heads,
use_flash_attn,
cfg,
)
})
@@ -789,6 +792,7 @@ impl CrossAttnUpBlock2D {
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
use_flash_attn: bool,
config: CrossAttnUpBlock2DConfig,
) -> Result<Self> {
let upblock = UpBlock2D::new(
@@ -815,6 +819,7 @@ impl CrossAttnUpBlock2D {
out_channels,
n_heads,
out_channels / n_heads,
use_flash_attn,
cfg,
)
})