mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example Add get_qkv_linear to handle different dimensionality in linears Add stable diffusion 3 example Add use_quant_conv and use_post_quant_conv for vae in stable diffusion adapt existing AutoEncoderKLConfig to the change add forward_until_encoder_layer to ClipTextTransformer rename sd3 config to sd3_medium in mmdit; minor clean-up Enable flash-attn for mmdit impl when the feature is enabled. Add sd3 example codebase add document crediting references pass the cargo fmt test pass the clippy test * fix typos * expose cfg_scale and time_shift as options * Replace the sample image with JPG version. Change image output format accordingly. * make meaningful error messages * remove the tail-end assignment in sd3_vae_vb_rename * remove the CUDA requirement * use default_value in clap args * add use_flash_attn to turn on/off flash-attn for MMDiT at runtime * resolve clippy errors and warnings * use default_value_t * Pin the web-sys dependency. * Clippy fix. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -194,10 +194,16 @@ pub struct JointBlock {
|
||||
x_block: DiTBlock,
|
||||
context_block: DiTBlock,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl JointBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
|
||||
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
|
||||
|
||||
@ -205,13 +211,15 @@ impl JointBlock {
|
||||
x_block,
|
||||
context_block,
|
||||
num_heads,
|
||||
use_flash_attn,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
|
||||
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
|
||||
let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
|
||||
let (context_attn, x_attn) =
|
||||
joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
|
||||
let context_out =
|
||||
self.context_block
|
||||
.post_attention(&context_attn, context, &context_interm)?;
|
||||
@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock {
|
||||
x_block: DiTBlock,
|
||||
context_block: QkvOnlyDiTBlock,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl ContextQkvOnlyJointBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
|
||||
let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
|
||||
Ok(Self {
|
||||
x_block,
|
||||
context_block,
|
||||
num_heads,
|
||||
use_flash_attn,
|
||||
})
|
||||
}
|
||||
|
||||
@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock {
|
||||
let context_qkv = self.context_block.pre_attention(context, c)?;
|
||||
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
|
||||
|
||||
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
|
||||
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
|
||||
|
||||
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
|
||||
Ok(x_out)
|
||||
@ -266,7 +281,28 @@ fn flash_compatible_attention(
|
||||
attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
|
||||
}
|
||||
|
||||
fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
||||
|
||||
fn joint_attn(
|
||||
context_qkv: &Qkv,
|
||||
x_qkv: &Qkv,
|
||||
num_heads: usize,
|
||||
use_flash_attn: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let qkv = Qkv {
|
||||
q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
|
||||
k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
|
||||
@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tenso
|
||||
|
||||
let headdim = qkv.q.dim(D::Minus1)?;
|
||||
let softmax_scale = 1.0 / (headdim as f64).sqrt();
|
||||
// let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
|
||||
let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
|
||||
|
||||
let attn = if use_flash_attn {
|
||||
flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
|
||||
} else {
|
||||
flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
|
||||
};
|
||||
|
||||
let attn = attn.reshape((batch_size, seqlen, ()))?;
|
||||
let context_qkv_seqlen = context_qkv.q.dim(1)?;
|
||||
|
@ -23,7 +23,7 @@ pub struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn sd3() -> Self {
|
||||
pub fn sd3_medium() -> Self {
|
||||
Self {
|
||||
patch_size: 2,
|
||||
in_channels: 16,
|
||||
@ -49,7 +49,7 @@ pub struct MMDiT {
|
||||
}
|
||||
|
||||
impl MMDiT {
|
||||
pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
|
||||
pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let hidden_size = cfg.head_size * cfg.depth;
|
||||
let core = MMDiTCore::new(
|
||||
cfg.depth,
|
||||
@ -57,6 +57,7 @@ impl MMDiT {
|
||||
cfg.depth,
|
||||
cfg.patch_size,
|
||||
cfg.out_channels,
|
||||
use_flash_attn,
|
||||
vb.clone(),
|
||||
)?;
|
||||
let patch_embedder = PatchEmbedder::new(
|
||||
@ -135,6 +136,7 @@ impl MMDiTCore {
|
||||
num_heads: usize,
|
||||
patch_size: usize,
|
||||
out_channels: usize,
|
||||
use_flash_attn: bool,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mut joint_blocks = Vec::with_capacity(depth - 1);
|
||||
@ -142,6 +144,7 @@ impl MMDiTCore {
|
||||
joint_blocks.push(JointBlock::new(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
use_flash_attn,
|
||||
vb.pp(format!("joint_blocks.{}", i)),
|
||||
)?);
|
||||
}
|
||||
@ -151,6 +154,7 @@ impl MMDiTCore {
|
||||
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
use_flash_attn,
|
||||
vb.pp(format!("joint_blocks.{}", depth - 1)),
|
||||
)?,
|
||||
final_layer: FinalLayer::new(
|
||||
|
@ -42,7 +42,6 @@ pub struct QkvOnlyAttnProjections {
|
||||
|
||||
impl QkvOnlyAttnProjections {
|
||||
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
// {'dim': 1536, 'num_heads': 24}
|
||||
let head_dim = dim / num_heads;
|
||||
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
|
||||
Ok(Self { qkv, head_dim })
|
||||
|
@ -467,6 +467,24 @@ pub struct AttentionBlock {
|
||||
config: AttentionBlockConfig,
|
||||
}
|
||||
|
||||
// In the .safetensor weights of official Stable Diffusion 3 Medium Huggingface repo
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-3-medium
|
||||
// Linear layer may use a different dimension for the weight in the linear, which is
|
||||
// incompatible with the current implementation of the nn::linear constructor.
|
||||
// This is a workaround to handle the different dimensions.
|
||||
fn get_qkv_linear(channels: usize, vs: nn::VarBuilder) -> Result<nn::Linear> {
|
||||
match vs.get((channels, channels), "weight") {
|
||||
Ok(_) => nn::linear(channels, channels, vs),
|
||||
Err(_) => {
|
||||
let weight = vs
|
||||
.get((channels, channels, 1, 1), "weight")?
|
||||
.reshape((channels, channels))?;
|
||||
let bias = vs.get((channels,), "bias")?;
|
||||
Ok(nn::Linear::new(weight, Some(bias)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AttentionBlock {
|
||||
pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
|
||||
let num_head_channels = config.num_head_channels.unwrap_or(channels);
|
||||
@ -478,10 +496,10 @@ impl AttentionBlock {
|
||||
} else {
|
||||
("query", "key", "value", "proj_attn")
|
||||
};
|
||||
let query = nn::linear(channels, channels, vs.pp(q_path))?;
|
||||
let key = nn::linear(channels, channels, vs.pp(k_path))?;
|
||||
let value = nn::linear(channels, channels, vs.pp(v_path))?;
|
||||
let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
|
||||
let query = get_qkv_linear(channels, vs.pp(q_path))?;
|
||||
let key = get_qkv_linear(channels, vs.pp(k_path))?;
|
||||
let value = get_qkv_linear(channels, vs.pp(v_path))?;
|
||||
let proj_attn = get_qkv_linear(channels, vs.pp(out_path))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-block");
|
||||
Ok(Self {
|
||||
group_norm,
|
||||
|
@ -388,6 +388,37 @@ impl ClipTextTransformer {
|
||||
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
|
||||
self.final_layer_norm.forward(&xs)
|
||||
}
|
||||
|
||||
pub fn forward_until_encoder_layer(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
mask_after: usize,
|
||||
until_layer: isize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (bsz, seq_len) = xs.dims2()?;
|
||||
let xs = self.embeddings.forward(xs)?;
|
||||
let causal_attention_mask =
|
||||
Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
|
||||
|
||||
let mut xs = xs.clone();
|
||||
let mut intermediate = xs.clone();
|
||||
|
||||
// Modified encoder.forward that returns the intermediate tensor along with final output.
|
||||
let until_layer = if until_layer < 0 {
|
||||
self.encoder.layers.len() as isize + until_layer
|
||||
} else {
|
||||
until_layer
|
||||
} as usize;
|
||||
|
||||
for (layer_id, layer) in self.encoder.layers.iter().enumerate() {
|
||||
xs = layer.forward(&xs, &causal_attention_mask)?;
|
||||
if layer_id == until_layer {
|
||||
intermediate = xs.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok((self.final_layer_norm.forward(&xs)?, intermediate))
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ClipTextTransformer {
|
||||
|
@ -65,6 +65,8 @@ impl StableDiffusionConfig {
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
use_quant_conv: true,
|
||||
use_post_quant_conv: true,
|
||||
};
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
@ -133,6 +135,8 @@ impl StableDiffusionConfig {
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
use_quant_conv: true,
|
||||
use_post_quant_conv: true,
|
||||
};
|
||||
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
|
||||
prediction_type,
|
||||
@ -214,6 +218,8 @@ impl StableDiffusionConfig {
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
use_quant_conv: true,
|
||||
use_post_quant_conv: true,
|
||||
};
|
||||
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
|
||||
prediction_type,
|
||||
@ -281,6 +287,8 @@ impl StableDiffusionConfig {
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
use_quant_conv: true,
|
||||
use_post_quant_conv: true,
|
||||
};
|
||||
let scheduler = Arc::new(
|
||||
euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
|
||||
@ -378,6 +386,8 @@ impl StableDiffusionConfig {
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
use_quant_conv: true,
|
||||
use_post_quant_conv: true,
|
||||
};
|
||||
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
|
||||
..Default::default()
|
||||
|
@ -275,6 +275,8 @@ pub struct AutoEncoderKLConfig {
|
||||
pub layers_per_block: usize,
|
||||
pub latent_channels: usize,
|
||||
pub norm_num_groups: usize,
|
||||
pub use_quant_conv: bool,
|
||||
pub use_post_quant_conv: bool,
|
||||
}
|
||||
|
||||
impl Default for AutoEncoderKLConfig {
|
||||
@ -284,6 +286,8 @@ impl Default for AutoEncoderKLConfig {
|
||||
layers_per_block: 1,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
use_quant_conv: true,
|
||||
use_post_quant_conv: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -315,8 +319,8 @@ impl DiagonalGaussianDistribution {
|
||||
pub struct AutoEncoderKL {
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
quant_conv: nn::Conv2d,
|
||||
post_quant_conv: nn::Conv2d,
|
||||
quant_conv: Option<nn::Conv2d>,
|
||||
post_quant_conv: Option<nn::Conv2d>,
|
||||
pub config: AutoEncoderKLConfig,
|
||||
}
|
||||
|
||||
@ -342,20 +346,33 @@ impl AutoEncoderKL {
|
||||
};
|
||||
let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
|
||||
let conv_cfg = Default::default();
|
||||
let quant_conv = nn::conv2d(
|
||||
2 * latent_channels,
|
||||
2 * latent_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("quant_conv"),
|
||||
)?;
|
||||
let post_quant_conv = nn::conv2d(
|
||||
latent_channels,
|
||||
latent_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("post_quant_conv"),
|
||||
)?;
|
||||
|
||||
let quant_conv = {
|
||||
if config.use_quant_conv {
|
||||
Some(nn::conv2d(
|
||||
2 * latent_channels,
|
||||
2 * latent_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("quant_conv"),
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
let post_quant_conv = {
|
||||
if config.use_post_quant_conv {
|
||||
Some(nn::conv2d(
|
||||
latent_channels,
|
||||
latent_channels,
|
||||
1,
|
||||
conv_cfg,
|
||||
vs.pp("post_quant_conv"),
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
@ -368,13 +385,19 @@ impl AutoEncoderKL {
|
||||
/// Returns the distribution in the latent space.
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
|
||||
let xs = self.encoder.forward(xs)?;
|
||||
let parameters = self.quant_conv.forward(&xs)?;
|
||||
let parameters = match &self.quant_conv {
|
||||
None => xs,
|
||||
Some(quant_conv) => quant_conv.forward(&xs)?,
|
||||
};
|
||||
DiagonalGaussianDistribution::new(¶meters)
|
||||
}
|
||||
|
||||
/// Takes as input some sampled values.
|
||||
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.post_quant_conv.forward(xs)?;
|
||||
self.decoder.forward(&xs)
|
||||
let xs = match &self.post_quant_conv {
|
||||
None => xs,
|
||||
Some(post_quant_conv) => &post_quant_conv.forward(xs)?,
|
||||
};
|
||||
self.decoder.forward(xs)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user