Flash-attention support in stable diffusion (#487)

* Add flash-attention for the stable-diffusion example.

* Change the dtype.

* Silly fix.

* Another fix.

* Revert the dtype back to the query dtype after apply flash-attn.
This commit is contained in:
Laurent Mazare
2023-08-17 12:16:40 +01:00
committed by GitHub
parent 03be33eea4
commit c3176f0dfb
5 changed files with 78 additions and 32 deletions

View File

@ -1,4 +1,3 @@
#![allow(dead_code)]
//! Attention Based Building Blocks
use candle::{IndexOp, Result, Tensor, D};
use candle_nn as nn;
@ -61,6 +60,22 @@ impl FeedForward {
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug)]
struct CrossAttention {
to_q: nn::Linear,
@ -72,6 +87,7 @@ struct CrossAttention {
slice_size: Option<usize>,
span: tracing::Span,
span_attn: tracing::Span,
use_flash_attn: bool,
}
impl CrossAttention {
@ -83,6 +99,7 @@ impl CrossAttention {
heads: usize,
dim_head: usize,
slice_size: Option<usize>,
use_flash_attn: bool,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let context_dim = context_dim.unwrap_or(query_dim);
@ -103,6 +120,7 @@ impl CrossAttention {
slice_size,
span,
span_attn,
use_flash_attn,
})
}
@ -146,8 +164,28 @@ impl CrossAttention {
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
let xs = if self.use_flash_attn {
let init_dtype = query.dtype();
let q = query
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let k = key
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let v = value
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
flash_attn(&q, &k, &v, self.scale as f32, false)?
.transpose(1, 2)?
.squeeze(0)?
.to_dtype(init_dtype)?
} else {
let xs = query.matmul(&(key.t()? * self.scale)?)?;
nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
};
self.reshape_batch_dim_to_heads(&xs)
}
@ -160,15 +198,17 @@ impl CrossAttention {
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
let xs = match self.slice_size {
None => self.attention(&query, &key, &value)?,
Some(slice_size) => {
if query.dim(0)? / slice_size <= 1 {
self.attention(&query, &key, &value)?
} else {
self.sliced_attention(&query, &key, &value, slice_size)?
}
let dim0 = query.dim(0)?;
let slice_size = self.slice_size.and_then(|slice_size| {
if dim0 < slice_size {
None
} else {
Some(slice_size)
}
});
let xs = match slice_size {
None => self.attention(&query, &key, &value)?,
Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
};
self.to_out.forward(&xs)
}
@ -194,6 +234,7 @@ impl BasicTransformerBlock {
d_head: usize,
context_dim: Option<usize>,
sliced_attention_size: Option<usize>,
use_flash_attn: bool,
) -> Result<Self> {
let attn1 = CrossAttention::new(
vs.pp("attn1"),
@ -202,6 +243,7 @@ impl BasicTransformerBlock {
n_heads,
d_head,
sliced_attention_size,
use_flash_attn,
)?;
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
let attn2 = CrossAttention::new(
@ -211,6 +253,7 @@ impl BasicTransformerBlock {
n_heads,
d_head,
sliced_attention_size,
use_flash_attn,
)?;
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
@ -279,6 +322,7 @@ impl SpatialTransformer {
in_channels: usize,
n_heads: usize,
d_head: usize,
use_flash_attn: bool,
config: SpatialTransformerConfig,
) -> Result<Self> {
let inner_dim = n_heads * d_head;
@ -304,6 +348,7 @@ impl SpatialTransformer {
d_head,
config.context_dim,
config.sliced_attention_size,
use_flash_attn,
)?;
transformer_blocks.push(tb)
}