mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example Add get_qkv_linear to handle different dimensionality in linears Add stable diffusion 3 example Add use_quant_conv and use_post_quant_conv for vae in stable diffusion adapt existing AutoEncoderKLConfig to the change add forward_until_encoder_layer to ClipTextTransformer rename sd3 config to sd3_medium in mmdit; minor clean-up Enable flash-attn for mmdit impl when the feature is enabled. Add sd3 example codebase add document crediting references pass the cargo fmt test pass the clippy test * fix typos * expose cfg_scale and time_shift as options * Replace the sample image with JPG version. Change image output format accordingly. * make meaningful error messages * remove the tail-end assignment in sd3_vae_vb_rename * remove the CUDA requirement * use default_value in clap args * add use_flash_attn to turn on/off flash-attn for MMDiT at runtime * resolve clippy errors and warnings * use default_value_t * Pin the web-sys dependency. * Clippy fix. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -194,10 +194,16 @@ pub struct JointBlock {
|
||||
x_block: DiTBlock,
|
||||
context_block: DiTBlock,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl JointBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
|
||||
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
|
||||
|
||||
@ -205,13 +211,15 @@ impl JointBlock {
|
||||
x_block,
|
||||
context_block,
|
||||
num_heads,
|
||||
use_flash_attn,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
|
||||
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
|
||||
let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
|
||||
let (context_attn, x_attn) =
|
||||
joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
|
||||
let context_out =
|
||||
self.context_block
|
||||
.post_attention(&context_attn, context, &context_interm)?;
|
||||
@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock {
|
||||
x_block: DiTBlock,
|
||||
context_block: QkvOnlyDiTBlock,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl ContextQkvOnlyJointBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
|
||||
let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
|
||||
Ok(Self {
|
||||
x_block,
|
||||
context_block,
|
||||
num_heads,
|
||||
use_flash_attn,
|
||||
})
|
||||
}
|
||||
|
||||
@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock {
|
||||
let context_qkv = self.context_block.pre_attention(context, c)?;
|
||||
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
|
||||
|
||||
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
|
||||
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
|
||||
|
||||
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
|
||||
Ok(x_out)
|
||||
@ -266,7 +281,28 @@ fn flash_compatible_attention(
|
||||
attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
|
||||
}
|
||||
|
||||
fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
||||
|
||||
fn joint_attn(
|
||||
context_qkv: &Qkv,
|
||||
x_qkv: &Qkv,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let qkv = Qkv {
|
||||
q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
|
||||
k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
|
||||
@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tenso
|
||||
|
||||
let headdim = qkv.q.dim(D::Minus1)?;
|
||||
let softmax_scale = 1.0 / (headdim as f64).sqrt();
|
||||
// let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
|
||||
let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
|
||||
|
||||
let attn = if use_flash_attn {
|
||||
flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
|
||||
} else {
|
||||
flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
|
||||
};
|
||||
|
||||
let attn = attn.reshape((batch_size, seqlen, ()))?;
|
||||
let context_qkv_seqlen = context_qkv.q.dim(1)?;
|
||||
|
Reference in New Issue
Block a user