Add Stable Diffusion 3 Example (#2558)

* Add Stable Diffusion 3 example

Add get_qkv_linear to handle different dimensionalities in the linear layers

Add the Stable Diffusion 3 example

Add use_quant_conv and use_post_quant_conv for the VAE in Stable Diffusion

Adapt the existing AutoEncoderKLConfig to the change

Add forward_until_encoder_layer to ClipTextTransformer

Rename the sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for the mmdit implementation when the feature is enabled

Add the sd3 example codebase

Add documentation

Credit references

Pass the cargo fmt check

Pass the clippy check

* Fix typos

* Expose cfg_scale and time_shift as options

* Replace the sample image with a JPG version and change the image output format accordingly

* Produce meaningful error messages

* Remove the tail-end assignment in sd3_vae_vb_rename

* Remove the CUDA requirement

* Use default_value in clap args

* Add use_flash_attn to toggle flash-attn for MMDiT at runtime

* Resolve clippy errors and warnings

* Use default_value_t in the clap args (a sketch of these CLI options follows this list)

* Pin the web-sys dependency

* Clippy fix
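
For reference, the sketch below shows how the CLI options named above (cfg_scale, time_shift, use_flash_attn) might be declared with clap's default_value_t. It is an illustration only: the struct name, the default values, and the prompt field are assumptions, not the example's actual code.

use clap::Parser;

/// Illustrative subset of the SD3 example's CLI options.
#[derive(Parser, Debug)]
struct Args {
    /// Prompt to render (placeholder default).
    #[arg(long, default_value = "A rusty robot holding a candle")]
    prompt: String,

    /// Classifier-free guidance scale (default value is illustrative).
    #[arg(long, default_value_t = 4.0)]
    cfg_scale: f64,

    /// Time-shift factor for the sampling schedule (default value is illustrative).
    #[arg(long, default_value_t = 3.0)]
    time_shift: f64,

    /// Turn flash-attn on for the MMDiT at runtime (requires building with
    /// the flash-attn feature).
    #[arg(long)]
    use_flash_attn: bool,
}

fn main() {
    let args = Args::parse();
    println!(
        "cfg_scale={}, time_shift={}, use_flash_attn={}",
        args.cfg_scale, args.time_shift, args.use_flash_attn
    );
}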

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Author: Czxck001
Date: 2024-10-13 13:08:40 -07:00
Committed by: GitHub
Parent: 0d96ec31e8
Commit: ca7cf5cb3b
16 changed files with 751 additions and 34 deletions


@@ -194,10 +194,16 @@ pub struct JointBlock {
     x_block: DiTBlock,
     context_block: DiTBlock,
     num_heads: usize,
+    use_flash_attn: bool,
 }
 
 impl JointBlock {
-    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+    pub fn new(
+        hidden_size: usize,
+        num_heads: usize,
+        use_flash_attn: bool,
+        vb: nn::VarBuilder,
+    ) -> Result<Self> {
         let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
         let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
@@ -205,13 +211,15 @@ impl JointBlock {
             x_block,
             context_block,
             num_heads,
+            use_flash_attn,
         })
     }
 
     pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
         let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
         let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
-        let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
+        let (context_attn, x_attn) =
+            joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
         let context_out =
             self.context_block
                 .post_attention(&context_attn, context, &context_interm)?;
@@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock {
     x_block: DiTBlock,
     context_block: QkvOnlyDiTBlock,
     num_heads: usize,
+    use_flash_attn: bool,
 }
 
 impl ContextQkvOnlyJointBlock {
-    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+    pub fn new(
+        hidden_size: usize,
+        num_heads: usize,
+        use_flash_attn: bool,
+        vb: nn::VarBuilder,
+    ) -> Result<Self> {
         let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
         let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
 
         Ok(Self {
             x_block,
             context_block,
             num_heads,
+            use_flash_attn,
         })
     }
@@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock {
         let context_qkv = self.context_block.pre_attention(context, c)?;
         let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
-        let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
+        let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
         let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
 
         Ok(x_out)
@@ -266,7 +281,28 @@ fn flash_compatible_attention(
     attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
 }
 
-fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
+fn joint_attn(
+    context_qkv: &Qkv,
+    x_qkv: &Qkv,
+    num_heads: usize,
+    use_flash_attn: bool,
+) -> Result<(Tensor, Tensor)> {
     let qkv = Qkv {
         q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
         k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
@@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
     let headdim = qkv.q.dim(D::Minus1)?;
     let softmax_scale = 1.0 / (headdim as f64).sqrt();
-    // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
-    let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
+    let attn = if use_flash_attn {
+        flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
+    } else {
+        flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
+    };
 
     let attn = attn.reshape((batch_size, seqlen, ()))?;
     let context_qkv_seqlen = context_qkv.q.dim(1)?;
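
The hunks above combine a compile-time feature gate with a runtime flag: the cfg-gated flash_attn stub keeps the crate compiling when the flash-attn feature is off, while use_flash_attn selects the kernel at runtime. A minimal, self-contained sketch of that pattern follows; fast_kernel and portable_kernel are made-up stand-ins, not candle's API.

// Feature-gated fast path with a portable fallback, selected at runtime.
// `fast_kernel` and `portable_kernel` stand in for
// candle_flash_attn::flash_attn and flash_compatible_attention.

#[cfg(feature = "flash-attn")]
fn fast_kernel(x: f32) -> f32 {
    // Stands in for the optimized kernel that only exists when the
    // feature (and its heavyweight dependency) is compiled in.
    x * 2.0
}

#[cfg(not(feature = "flash-attn"))]
fn fast_kernel(_x: f32) -> f32 {
    // Keeps call sites compiling without the feature; reaching this at
    // runtime is reported as an error, as in the hunk above.
    unimplemented!("compile with '--features flash-attn'")
}

fn portable_kernel(x: f32) -> f32 {
    x * 2.0
}

fn compute(x: f32, use_flash_attn: bool) -> f32 {
    // The runtime flag picks the kernel; both branches type-check
    // regardless of which features were enabled at build time.
    if use_flash_attn {
        fast_kernel(x)
    } else {
        portable_kernel(x)
    }
}

fn main() {
    println!("{}", compute(1.5, false));
}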