Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Flash-attention support in stable diffusion (#487)
* Add flash-attention for the stable-diffusion example.
* Change the dtype.
* Silly fix.
* Another fix.
* Revert the dtype back to the query dtype after applying flash-attn.
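The last bullet is the subtle one: flash-attention kernels run in half precision, so the example casts the inputs to f16 before the kernel and casts the output back to the query's original dtype afterwards. A minimal sketch of that round-trip using candle's public Tensor API (shapes and values are illustrative, not taken from the diff):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let query = Tensor::randn(0f32, 1f32, (2, 4, 8), &Device::Cpu)?;
    // Remember the caller's dtype before casting to f16 for the kernel...
    let init_dtype = query.dtype();
    let q = query.to_dtype(DType::F16)?;
    // ...and cast the result back to the query dtype afterwards.
    let out = q.to_dtype(init_dtype)?;
    assert_eq!(out.dtype(), DType::F32);
    Ok(())
}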
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 //! Attention Based Building Blocks
 use candle::{IndexOp, Result, Tensor, D};
 use candle_nn as nn;
@@ -61,6 +60,22 @@ impl FeedForward {
     }
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug)]
 struct CrossAttention {
     to_q: nn::Linear,
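The wrapper above is the standard Cargo feature-gate fallback: the real kernel is only compiled in when the flash-attn feature is enabled, while a stub with the same signature keeps call sites compiling otherwise. A standalone sketch of the same pattern, assuming a hypothetical Cargo feature named "fast-path" (not part of this commit); like the stub above, the fallback panics if it is ever called:

#[cfg(feature = "fast-path")]
fn compute(x: f32) -> f32 {
    // Only compiled when the (hypothetical) "fast-path" feature is enabled.
    x * 2.0
}

#[cfg(not(feature = "fast-path"))]
fn compute(_: f32) -> f32 {
    // Keeps the build green when the feature is off; panics if actually used.
    unimplemented!("compile with '--features fast-path'")
}

fn main() {
    println!("{}", compute(21.0));
}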
@@ -72,6 +87,7 @@ struct CrossAttention {
     slice_size: Option<usize>,
     span: tracing::Span,
     span_attn: tracing::Span,
+    use_flash_attn: bool,
 }
 
 impl CrossAttention {
@@ -83,6 +99,7 @@ impl CrossAttention {
         heads: usize,
         dim_head: usize,
         slice_size: Option<usize>,
+        use_flash_attn: bool,
     ) -> Result<Self> {
         let inner_dim = dim_head * heads;
         let context_dim = context_dim.unwrap_or(query_dim);
@@ -103,6 +120,7 @@ impl CrossAttention {
             slice_size,
             span,
             span_attn,
+            use_flash_attn,
         })
     }
 
@@ -146,8 +164,28 @@ impl CrossAttention {
 
     fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
         let _enter = self.span_attn.enter();
-        let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
-        let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
+        let xs = if self.use_flash_attn {
+            let init_dtype = query.dtype();
+            let q = query
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let k = key
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let v = value
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            flash_attn(&q, &k, &v, self.scale as f32, false)?
+                .transpose(1, 2)?
+                .squeeze(0)?
+                .to_dtype(init_dtype)?
+        } else {
+            let xs = query.matmul(&(key.t()? * self.scale)?)?;
+            nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
+        };
         self.reshape_batch_dim_to_heads(&xs)
     }
 
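As I read the candle_flash_attn wrapper, flash_attn takes f16 tensors shaped (batch, seq_len, num_heads, head_dim), while attention operates on (batch * heads, seq_len, head_dim); hence the unsqueeze/transpose pairs above, which present the fused (batch * heads) axis as the heads axis of a batch of one. A hedged sketch of just the shape and dtype round-trip, with made-up dimensions (16 = batch * heads, 64 = seq_len, 40 = head_dim):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let query = Tensor::randn(0f32, 1f32, (16, 64, 40), &Device::Cpu)?;
    let init_dtype = query.dtype();
    // (batch * heads, seq, head_dim) -> (1, seq, batch * heads, head_dim), in f16.
    let q = query.to_dtype(DType::F16)?.unsqueeze(0)?.transpose(1, 2)?;
    assert_eq!(q.dims(), &[1, 64, 16, 40]);
    // The inverse transform restores the layout and dtype the caller expects.
    let out = q.transpose(1, 2)?.squeeze(0)?.to_dtype(init_dtype)?;
    assert_eq!(out.dims(), query.dims());
    Ok(())
}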
@@ -160,15 +198,17 @@ impl CrossAttention {
         let query = self.reshape_heads_to_batch_dim(&query)?;
         let key = self.reshape_heads_to_batch_dim(&key)?;
         let value = self.reshape_heads_to_batch_dim(&value)?;
-        let xs = match self.slice_size {
-            None => self.attention(&query, &key, &value)?,
-            Some(slice_size) => {
-                if query.dim(0)? / slice_size <= 1 {
-                    self.attention(&query, &key, &value)?
-                } else {
-                    self.sliced_attention(&query, &key, &value, slice_size)?
-                }
-            }
+        let dim0 = query.dim(0)?;
+        let slice_size = self.slice_size.and_then(|slice_size| {
+            if dim0 < slice_size {
+                None
+            } else {
+                Some(slice_size)
+            }
+        });
+        let xs = match slice_size {
+            None => self.attention(&query, &key, &value)?,
+            Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
         };
         self.to_out.forward(&xs)
     }
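This rewrite replaces the nested match with a normalization step: slice_size is collapsed to None whenever the batch dimension is too small to slice, so the match below it stays flat. The same logic factored out as a pure function — a minimal sketch, with made-up numbers in the asserts:

fn effective_slice_size(dim0: usize, slice_size: Option<usize>) -> Option<usize> {
    // Disable slicing when the batch dimension is smaller than one slice.
    slice_size.and_then(|s| if dim0 < s { None } else { Some(s) })
}

fn main() {
    assert_eq!(effective_slice_size(2, Some(4)), None); // too small: plain attention
    assert_eq!(effective_slice_size(16, Some(4)), Some(4)); // sliced attention
    assert_eq!(effective_slice_size(16, None), None); // slicing not requested
}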
@@ -194,6 +234,7 @@ impl BasicTransformerBlock {
         d_head: usize,
         context_dim: Option<usize>,
         sliced_attention_size: Option<usize>,
+        use_flash_attn: bool,
     ) -> Result<Self> {
         let attn1 = CrossAttention::new(
             vs.pp("attn1"),
@@ -202,6 +243,7 @@ impl BasicTransformerBlock {
             n_heads,
             d_head,
             sliced_attention_size,
+            use_flash_attn,
         )?;
         let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
         let attn2 = CrossAttention::new(
@@ -211,6 +253,7 @@ impl BasicTransformerBlock {
             n_heads,
             d_head,
             sliced_attention_size,
+            use_flash_attn,
         )?;
         let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
         let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
@@ -279,6 +322,7 @@ impl SpatialTransformer {
         in_channels: usize,
         n_heads: usize,
         d_head: usize,
+        use_flash_attn: bool,
         config: SpatialTransformerConfig,
     ) -> Result<Self> {
         let inner_dim = n_heads * d_head;
@@ -304,6 +348,7 @@ impl SpatialTransformer {
                 d_head,
                 config.context_dim,
                 config.sliced_attention_size,
+                use_flash_attn,
             )?;
             transformer_blocks.push(tb)
         }
 
@@ -90,6 +90,9 @@ struct Args {
     /// Generate intermediary images at each step.
     #[arg(long, action)]
     intermediary_images: bool,
 
+    #[arg(long)]
+    use_flash_attn: bool,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
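With clap's derive API, a bare #[arg(long)] on a bool field yields an off-by-default --use-flash-attn switch, so existing invocations keep their behavior. A self-contained sketch of just this flag (assuming clap 4 with the derive feature enabled):

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    // Becomes the `--use-flash-attn` switch; false unless passed.
    #[arg(long)]
    use_flash_attn: bool,
}

fn main() {
    let args = Args::parse();
    println!("use_flash_attn: {}", args.use_flash_attn);
}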
@@ -268,7 +271,7 @@ fn run(args: Args) -> Result<()> {
     let vae = sd_config.build_vae(&vae_weights, &device)?;
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
+    let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;
 
     let bsize = 1;
     for idx in 0..num_samples {
 
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 use crate::schedulers::PredictionType;
 use crate::{clip, ddim, unet_2d, vae};
 use candle::{DType, Device, Result};
@@ -156,22 +155,6 @@ impl StableDiffusionConfig {
         )
     }
 
-    pub fn v2_1_inpaint(
-        sliced_attention_size: Option<usize>,
-        height: Option<usize>,
-        width: Option<usize>,
-    ) -> Self {
-        // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
-        // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
-        // type being "epsilon" by default and not "v_prediction".
-        Self::v2_1_(
-            sliced_attention_size,
-            height,
-            width,
-            PredictionType::Epsilon,
-        )
-    }
-
     pub fn build_vae<P: AsRef<std::path::Path>>(
         &self,
         vae_weights: P,
@@ -190,11 +173,18 @@ impl StableDiffusionConfig {
         unet_weights: P,
         device: &Device,
         in_channels: usize,
+        use_flash_attn: bool,
     ) -> Result<unet_2d::UNet2DConditionModel> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
         let weights = weights.deserialize()?;
         let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
-        let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
+        let unet = unet_2d::UNet2DConditionModel::new(
+            vs_unet,
+            in_channels,
+            4,
+            use_flash_attn,
+            self.unet.clone(),
+        )?;
         Ok(unet)
     }
 
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 //! 2D UNet Denoising Models
 //!
 //! The 2D Unet models take as input a noisy sample and the current diffusion
@@ -103,6 +102,7 @@ impl UNet2DConditionModel {
         vs: nn::VarBuilder,
         in_channels: usize,
         out_channels: usize,
+        use_flash_attn: bool,
         config: UNet2DConditionModelConfig,
     ) -> Result<Self> {
         let n_blocks = config.blocks.len();
@@ -161,6 +161,7 @@ impl UNet2DConditionModel {
                     in_channels,
                     out_channels,
                     Some(time_embed_dim),
+                    use_flash_attn,
                     config,
                 )?;
                 Ok(UNetDownBlock::CrossAttn(block))
@@ -190,6 +191,7 @@ impl UNet2DConditionModel {
             vs.pp("mid_block"),
             bl_channels,
             Some(time_embed_dim),
+            use_flash_attn,
            mid_cfg,
         )?;
 
@@ -242,6 +244,7 @@ impl UNet2DConditionModel {
                     prev_out_channels,
                     out_channels,
                     Some(time_embed_dim),
+                    use_flash_attn,
                     config,
                 )?;
                 Ok(UNetUpBlock::CrossAttn(block))
 
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 //! 2D UNet Building Blocks
 //!
 use crate::attention::{
@@ -393,6 +392,7 @@ impl UNetMidBlock2DCrossAttn {
         vs: nn::VarBuilder,
         in_channels: usize,
         temb_channels: Option<usize>,
+        use_flash_attn: bool,
         config: UNetMidBlock2DCrossAttnConfig,
     ) -> Result<Self> {
         let vs_resnets = vs.pp("resnets");
@@ -423,6 +423,7 @@ impl UNetMidBlock2DCrossAttn {
                 in_channels,
                 n_heads,
                 in_channels / n_heads,
+                use_flash_attn,
                 attn_cfg,
             )?;
             let resnet = ResnetBlock2D::new(
@@ -588,6 +589,7 @@ impl CrossAttnDownBlock2D {
         in_channels: usize,
         out_channels: usize,
         temb_channels: Option<usize>,
+        use_flash_attn: bool,
         config: CrossAttnDownBlock2DConfig,
     ) -> Result<Self> {
         let downblock = DownBlock2D::new(
@@ -613,6 +615,7 @@ impl CrossAttnDownBlock2D {
                 out_channels,
                 n_heads,
                 out_channels / n_heads,
+                use_flash_attn,
                 cfg,
             )
         })
@@ -789,6 +792,7 @@ impl CrossAttnUpBlock2D {
         prev_output_channels: usize,
         out_channels: usize,
         temb_channels: Option<usize>,
+        use_flash_attn: bool,
         config: CrossAttnUpBlock2DConfig,
     ) -> Result<Self> {
         let upblock = UpBlock2D::new(
@@ -815,6 +819,7 @@ impl CrossAttnUpBlock2D {
             out_channels,
             n_heads,
             out_channels / n_heads,
+            use_flash_attn,
             cfg,
         )
         })
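To exercise the new path end to end, both the compile-time feature and the runtime flag are presumably needed, along the lines of `cargo run --example stable-diffusion --release --features flash-attn -- --use-flash-attn`. The feature name comes from the unimplemented! message above, and the flag name is what clap derives from the use_flash_attn field; the exact invocation depends on how the example's Cargo.toml wires the feature through. With the flag left off, the example keeps its previous behavior.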