Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 19:58:35 +00:00
Flash-attention support in stable diffusion (#487)
* Add flash-attention for the stable-diffusion example.
* Change the dtype.
* Silly fix.
* Another fix.
* Revert the dtype back to the query dtype after applying flash-attn.
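The dtype bullets above refer to the flash-attention path running in f16: the inputs are cast to f16 before the kernel call and the result is cast back to whatever dtype the query came in with. A minimal sketch of that round-trip, not part of the commit; `with_f16_roundtrip` and `f16_kernel` are made-up names standing in for the gated `flash_attn` helper added below:

use candle::{DType, Result, Tensor};

// Illustrative only: run an f16-only kernel on a tensor of any float dtype
// and hand the result back in the caller's original dtype.
fn with_f16_roundtrip(
    q: &Tensor,
    f16_kernel: impl Fn(&Tensor) -> Result<Tensor>,
) -> Result<Tensor> {
    let init_dtype = q.dtype(); // remember the query dtype
    let out = f16_kernel(&q.to_dtype(DType::F16)?)?; // kernel sees f16 input
    out.to_dtype(init_dtype) // revert to the original dtype afterwards
}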
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 //! Attention Based Building Blocks
 use candle::{IndexOp, Result, Tensor, D};
 use candle_nn as nn;
@@ -61,6 +60,22 @@ impl FeedForward {
     }
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug)]
 struct CrossAttention {
     to_q: nn::Linear,
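With the `flash-attn` feature enabled, the first `flash_attn` above forwards to `candle_flash_attn`; without it, the second definition keeps call sites compiling but panics if the flash path is ever reached at runtime. A self-contained sketch of the same cfg-gating pattern, with a made-up `fast-path` feature and plain floats instead of tensors:

#[cfg(feature = "fast-path")]
fn dot(a: f32, b: f32) -> f32 {
    a * b // imagine an optimized kernel here
}

#[cfg(not(feature = "fast-path"))]
fn dot(_: f32, _: f32) -> f32 {
    // Same signature, so callers always compile; reaching this without the
    // feature panics with a hint about how to rebuild.
    unimplemented!("compile with '--features fast-path'")
}

fn main() {
    // Panics unless built with `--features fast-path`.
    println!("{}", dot(2.0, 3.0));
}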
@@ -72,6 +87,7 @@ struct CrossAttention {
     slice_size: Option<usize>,
     span: tracing::Span,
     span_attn: tracing::Span,
+    use_flash_attn: bool,
 }
 
 impl CrossAttention {
@@ -83,6 +99,7 @@ impl CrossAttention {
         heads: usize,
         dim_head: usize,
         slice_size: Option<usize>,
+        use_flash_attn: bool,
     ) -> Result<Self> {
         let inner_dim = dim_head * heads;
         let context_dim = context_dim.unwrap_or(query_dim);
@@ -103,6 +120,7 @@ impl CrossAttention {
             slice_size,
             span,
             span_attn,
+            use_flash_attn,
         })
     }
 
@@ -146,8 +164,28 @@ impl CrossAttention {
 
     fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
         let _enter = self.span_attn.enter();
-        let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
-        let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
+        let xs = if self.use_flash_attn {
+            let init_dtype = query.dtype();
+            let q = query
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let k = key
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let v = value
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            flash_attn(&q, &k, &v, self.scale as f32, false)?
+                .transpose(1, 2)?
+                .squeeze(0)?
+                .to_dtype(init_dtype)?
+        } else {
+            let xs = query.matmul(&(key.t()? * self.scale)?)?;
+            nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
+        };
         self.reshape_batch_dim_to_heads(&xs)
     }
 
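The flash path above also reshapes the `(batch * heads, seq_len, head_dim)` tensors that `attention` receives into a single-batch layout before calling the kernel, and undoes it afterwards. A shape-only sketch of that round-trip, not part of the commit; the sizes are invented and the kernel call is omitted since flash-attn needs CUDA and f16:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for a query of shape (batch * heads, seq_len, head_dim).
    let q = Tensor::zeros((8, 77, 40), DType::F32, &dev)?;
    let q_flash = q.to_dtype(DType::F16)?.unsqueeze(0)?.transpose(1, 2)?;
    assert_eq!(q_flash.dims(), &[1, 77, 8, 40]); // layout handed to flash_attn
    // ... flash_attn(&q_flash, &k_flash, &v_flash, scale, false)? would run here ...
    let back = q_flash.transpose(1, 2)?.squeeze(0)?.to_dtype(DType::F32)?;
    assert_eq!(back.dims(), q.dims()); // original layout restored
    Ok(())
}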
@@ -160,15 +198,17 @@ impl CrossAttention {
         let query = self.reshape_heads_to_batch_dim(&query)?;
         let key = self.reshape_heads_to_batch_dim(&key)?;
         let value = self.reshape_heads_to_batch_dim(&value)?;
-        let xs = match self.slice_size {
-            None => self.attention(&query, &key, &value)?,
-            Some(slice_size) => {
-                if query.dim(0)? / slice_size <= 1 {
-                    self.attention(&query, &key, &value)?
-                } else {
-                    self.sliced_attention(&query, &key, &value, slice_size)?
-                }
+        let dim0 = query.dim(0)?;
+        let slice_size = self.slice_size.and_then(|slice_size| {
+            if dim0 < slice_size {
+                None
+            } else {
+                Some(slice_size)
             }
+        });
+        let xs = match slice_size {
+            None => self.attention(&query, &key, &value)?,
+            Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
         };
         self.to_out.forward(&xs)
     }
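The rewritten forward path above replaces the nested match: slicing now happens only when the batched head dimension is at least `slice_size`, otherwise the whole batch goes through one full attention pass. The decision in isolation, as a standalone sketch (`effective_slice_size` is a made-up name, not part of the commit):

fn effective_slice_size(dim0: usize, slice_size: Option<usize>) -> Option<usize> {
    // Mirror of the `and_then` above: disable slicing for small batches.
    slice_size.and_then(|s| if dim0 < s { None } else { Some(s) })
}

fn main() {
    assert_eq!(effective_slice_size(16, Some(4)), Some(4)); // sliced attention
    assert_eq!(effective_slice_size(2, Some(4)), None); // falls back to a full pass
    assert_eq!(effective_slice_size(16, None), None); // slicing disabled entirely
}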
@@ -194,6 +234,7 @@ impl BasicTransformerBlock {
         d_head: usize,
         context_dim: Option<usize>,
         sliced_attention_size: Option<usize>,
+        use_flash_attn: bool,
     ) -> Result<Self> {
         let attn1 = CrossAttention::new(
             vs.pp("attn1"),
@@ -202,6 +243,7 @@ impl BasicTransformerBlock {
             n_heads,
             d_head,
             sliced_attention_size,
+            use_flash_attn,
         )?;
         let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
         let attn2 = CrossAttention::new(
@@ -211,6 +253,7 @@ impl BasicTransformerBlock {
             n_heads,
             d_head,
             sliced_attention_size,
+            use_flash_attn,
         )?;
         let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
         let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
@@ -279,6 +322,7 @@ impl SpatialTransformer {
         in_channels: usize,
         n_heads: usize,
         d_head: usize,
+        use_flash_attn: bool,
         config: SpatialTransformerConfig,
     ) -> Result<Self> {
         let inner_dim = n_heads * d_head;
@@ -304,6 +348,7 @@ impl SpatialTransformer {
             d_head,
             config.context_dim,
             config.sliced_attention_size,
+            use_flash_attn,
         )?;
         transformer_blocks.push(tb)
     }