Support sd3.5 medium and MMDiT-X (#2587)

* extract attn out of joint_attn

* further adjust attn and joint_attn

* add mmdit-x support

* support sd3.5-medium in the example

* update README.md
Author: Czxck001
Committed: 2024-10-29 22:19:07 -07:00 (committed by GitHub)
Parent: 139ff56aeb
Commit: d232e132f6
4 changed files with 276 additions and 42 deletions


@@ -1,8 +1,8 @@
-# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium
+# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3/3.5

 ![](assets/stable-diffusion-3.jpg)

-*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*
+*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*, generated by Stable Diffusion 3 Medium

 Stable Diffusion 3 Medium is a text-to-image model based on the Multimodal Diffusion Transformer (MMDiT) architecture.
@@ -10,9 +10,17 @@ Stable Diffusion 3 Medium is a text-to-image model based on the Multimodal Diffusion
 - [research paper](https://arxiv.org/pdf/2403.03206)
 - [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)

+Stable Diffusion 3.5 is a family of text-to-image models with the latest improvements:
+- [announcement blog post](https://stability.ai/news/introducing-stable-diffusion-3-5)
+
+It has three variants:
+- [Stable Diffusion 3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) @ 8.1b params, with a scaled and slightly modified MMDiT architecture.
+- [Stable Diffusion 3.5 Large Turbo](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo), a distilled version that enables 4-step inference.
+- [Stable Diffusion 3.5 Medium](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) @ 2.5b params, with the improved MMDiT-X architecture.
+
 ## Getting access to the weights

-The weights of Stable Diffusion 3 Medium are released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting [the repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account.
+The weights of Stable Diffusion 3/3.5 are released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the repos on HuggingFace Hub to gain access to the weights for your HuggingFace account.

 To allow your computer to gain access to the public-gated repos on HuggingFace, you might need to create a [HuggingFace User Access Token](https://huggingface.co/docs/hub/en/security-tokens) (recommended) and log in on your computer if you haven't done that before. A convenient way to do the login is to use [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli):
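
For reference, a minimal login sketch (`huggingface-cli` ships with the `huggingface_hub` Python package):

```shell
# Install the Hugging Face Hub client, then log in with your User Access Token.
pip install -U huggingface_hub
huggingface-cli login
```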
@@ -27,10 +35,12 @@ On the first run, the weights will be automatically downloaded from the HuggingFace Hub
 ```shell
 cargo run --example stable-diffusion-3 --release --features=cuda -- \
-  --height 1024 --width 1024 \
+  --which 3-medium --height 1024 --width 1024 \
   --prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
 ```
+
+To use a different model, change the value of the `--which` option. (Possible values: `3-medium`, `3.5-large`, `3.5-large-turbo` and `3.5-medium`.)
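
For example, a Stable Diffusion 3.5 Medium run would look like this (a sketch based on the command above, with the prompt shortened):

```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
  --which 3.5-medium --height 1024 --width 1024 \
  --prompt 'A cute rusty robot holding a candle torch in its hand'
```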
 To display other options available,

 ```shell
@@ -45,7 +55,7 @@ cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- -
 ## Performance Benchmark

-The benchmark below is done by generating a 1024-by-1024 image from 28 steps of Euler sampling and measuring the average speed (iterations per second).
+The benchmark below is done with Stable Diffusion 3 Medium by generating a 1024-by-1024 image from 28 steps of Euler sampling and measuring the average speed (iterations per second).

 [candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) are based on commit [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
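
The benchmark can presumably be reproduced with a command along these lines (a sketch; `--num-inference-steps` is inferred from the example's CLI fields shown below and should be double-checked against `--help`):

```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- \
  --which 3-medium --height 1024 --width 1024 \
  --num-inference-steps 28 \
  --prompt 'A cute rusty robot holding a candle torch in its hand'
```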


@@ -19,13 +19,15 @@ enum Which {
     V3_5Large,
     #[value(name = "3.5-large-turbo")]
     V3_5LargeTurbo,
+    #[value(name = "3.5-medium")]
+    V3_5Medium,
 }

 impl Which {
     fn is_3_5(&self) -> bool {
         match self {
             Self::V3Medium => false,
-            Self::V3_5Large | Self::V3_5LargeTurbo => true,
+            Self::V3_5Large | Self::V3_5LargeTurbo | Self::V3_5Medium => true,
         }
     }
 }
@@ -117,36 +119,59 @@ fn main() -> Result<()> {
     let default_inference_steps = match which {
         Which::V3_5Large => 28,
         Which::V3_5LargeTurbo => 4,
+        Which::V3_5Medium => 28,
         Which::V3Medium => 28,
     };
     let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
     let default_cfg_scale = match which {
         Which::V3_5Large => 4.0,
         Which::V3_5LargeTurbo => 1.0,
+        Which::V3_5Medium => 4.0,
         Which::V3Medium => 4.0,
     };
     let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);

     let api = hf_hub::api::sync::Api::new()?;
     let (mmdit_config, mut triple, vb) = if which.is_3_5() {
-        let sai_repo = {
+        let sai_repo_for_text_encoders = {
             let name = match which {
                 Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
                 Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
+                // Unfortunately, stabilityai/stable-diffusion-3.5-medium doesn't have the monolithic text
+                // encoders that are usually placed under the text_encoders directory, as is the case in
+                // stabilityai/stable-diffusion-3.5-large and -large-turbo. To make things worse, it currently
+                // only has the partitioned model.fp16-00001-of-00002.safetensors and
+                // model.fp16-00002-of-00002.safetensors under the text_encoder_3 directory for the
+                // t5xxl_fp16.safetensors model, so we would need to merge the two partitions to get the
+                // monolithic text encoder, which is not a trivial task. Since the situation can change, we
+                // don't want to spend effort handling this uniqueness of stabilityai/stable-diffusion-3.5-medium,
+                // so for now we use the text encoder models from the stabilityai/stable-diffusion-3.5-large
+                // repository instead.
+                // TODO: Change to "stabilityai/stable-diffusion-3.5-medium" once the maintainers of the
+                // repository add back the monolithic text encoders.
+                Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-large",
                 Which::V3Medium => unreachable!(),
             };
             api.repo(hf_hub::Repo::model(name.to_string()))
         };
-        let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?;
-        let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?;
-        let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?;
+        let sai_repo_for_mmdit = {
+            let name = match which {
+                Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
+                Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
+                Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-medium",
+                Which::V3Medium => unreachable!(),
+            };
+            api.repo(hf_hub::Repo::model(name.to_string()))
+        };
+        let clip_g_file = sai_repo_for_text_encoders.get("text_encoders/clip_g.safetensors")?;
+        let clip_l_file = sai_repo_for_text_encoders.get("text_encoders/clip_l.safetensors")?;
+        let t5xxl_file = sai_repo_for_text_encoders.get("text_encoders/t5xxl_fp16.safetensors")?;
         let model_file = {
             let model_file = match which {
                 Which::V3_5Large => "sd3.5_large.safetensors",
                 Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
+                Which::V3_5Medium => "sd3.5_medium.safetensors",
                 Which::V3Medium => unreachable!(),
             };
-            sai_repo.get(model_file)?
+            sai_repo_for_mmdit.get(model_file)?
         };
         let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
             &clip_g_file,
@@ -157,7 +182,12 @@ fn main() -> Result<()> {
         let vb = unsafe {
             candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
         };
-        (MMDiTConfig::sd3_5_large(), triple, vb)
+        match which {
+            Which::V3_5Large => (MMDiTConfig::sd3_5_large(), triple, vb),
+            Which::V3_5LargeTurbo => (MMDiTConfig::sd3_5_large(), triple, vb),
+            Which::V3_5Medium => (MMDiTConfig::sd3_5_medium(), triple, vb),
+            Which::V3Medium => unreachable!(),
+        }
     } else {
         let sai_repo = {
             let name = "stabilityai/stable-diffusion-3-medium";


@@ -36,7 +36,6 @@ impl Module for LayerNormNoAffine {
 impl DiTBlock {
     pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
-        // {'hidden_size': 1536, 'num_heads': 24}
         let norm1 = LayerNormNoAffine::new(1e-6);
         let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
         let norm2 = LayerNormNoAffine::new(1e-6);
@@ -103,6 +102,117 @@ impl DiTBlock {
     }
 }

+pub struct SelfAttnModulateIntermediates {
+    gate_msa: Tensor,
+    shift_mlp: Tensor,
+    scale_mlp: Tensor,
+    gate_mlp: Tensor,
+    gate_msa2: Tensor,
+}
+
+// A DiT block with an additional, parallel self-attention branch (attn2); this is the
+// x_block used by MMDiT-X.
+pub struct SelfAttnDiTBlock {
+    norm1: LayerNormNoAffine,
+    attn: AttnProjections,
+    attn2: AttnProjections,
+    norm2: LayerNormNoAffine,
+    mlp: Mlp,
+    ada_ln_modulation: nn::Sequential,
+}
+
+impl SelfAttnDiTBlock {
+    pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+        let norm1 = LayerNormNoAffine::new(1e-6);
+        let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
+        let attn2 = AttnProjections::new(hidden_size, num_heads, vb.pp("attn2"))?;
+        let norm2 = LayerNormNoAffine::new(1e-6);
+        let mlp_ratio = 4;
+        let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
+        // Nine modulation parameters: the six of a regular DiT block plus
+        // shift/scale/gate for the second self-attention branch.
+        let n_mods = 9;
+        let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
+            hidden_size,
+            n_mods * hidden_size,
+            vb.pp("adaLN_modulation.1"),
+        )?);
+        Ok(Self {
+            norm1,
+            attn,
+            attn2,
+            norm2,
+            mlp,
+            ada_ln_modulation,
+        })
+    }
+
+    pub fn pre_attention(
+        &self,
+        x: &Tensor,
+        c: &Tensor,
+    ) -> Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> {
+        let modulation = self.ada_ln_modulation.forward(c)?;
+        let chunks = modulation.chunk(9, D::Minus1)?;
+        let (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            shift_msa2,
+            scale_msa2,
+            gate_msa2,
+        ) = (
+            chunks[0].clone(),
+            chunks[1].clone(),
+            chunks[2].clone(),
+            chunks[3].clone(),
+            chunks[4].clone(),
+            chunks[5].clone(),
+            chunks[6].clone(),
+            chunks[7].clone(),
+            chunks[8].clone(),
+        );
+
+        let norm_x = self.norm1.forward(x)?;
+        let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
+        let qkv = self.attn.pre_attention(&modulated_x)?;
+        let modulated_x2 = modulate(&norm_x, &shift_msa2, &scale_msa2)?;
+        let qkv2 = self.attn2.pre_attention(&modulated_x2)?;
+        Ok((
+            qkv,
+            qkv2,
+            SelfAttnModulateIntermediates {
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                gate_msa2,
+            },
+        ))
+    }
+
+    pub fn post_attention(
+        &self,
+        attn: &Tensor,
+        attn2: &Tensor,
+        x: &Tensor,
+        mod_interm: &SelfAttnModulateIntermediates,
+    ) -> Result<Tensor> {
+        let attn_out = self.attn.post_attention(attn)?;
+        let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
+        let attn_out2 = self.attn2.post_attention(attn2)?;
+        let x = x.add(&attn_out2.broadcast_mul(&mod_interm.gate_msa2.unsqueeze(1)?)?)?;
+
+        let norm_x = self.norm2.forward(&x)?;
+        let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
+        let mlp_out = self.mlp.forward(&modulated_x)?;
+        let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
+        Ok(x)
+    }
+}
+
 pub struct QkvOnlyDiTBlock {
     norm1: LayerNormNoAffine,
     attn: QkvOnlyAttnProjections,
@@ -190,14 +300,18 @@ fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
     shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
 }

-pub struct JointBlock {
+pub trait JointBlock {
+    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)>;
+}
+
+pub struct MMDiTJointBlock {
     x_block: DiTBlock,
     context_block: DiTBlock,
     num_heads: usize,
     use_flash_attn: bool,
 }

-impl JointBlock {
+impl MMDiTJointBlock {
     pub fn new(
         hidden_size: usize,
         num_heads: usize,
@@ -214,8 +328,10 @@ impl JointBlock {
             use_flash_attn,
         })
     }
+}

-    pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
+impl JointBlock for MMDiTJointBlock {
+    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
         let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
         let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
         let (context_attn, x_attn) =
@@ -228,6 +344,49 @@ impl JointBlock {
     }
 }

+pub struct MMDiTXJointBlock {
+    x_block: SelfAttnDiTBlock,
+    context_block: DiTBlock,
+    num_heads: usize,
+    use_flash_attn: bool,
+}
+
+impl MMDiTXJointBlock {
+    pub fn new(
+        hidden_size: usize,
+        num_heads: usize,
+        use_flash_attn: bool,
+        vb: nn::VarBuilder,
+    ) -> Result<Self> {
+        let x_block = SelfAttnDiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
+        let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
+        Ok(Self {
+            x_block,
+            context_block,
+            num_heads,
+            use_flash_attn,
+        })
+    }
+}
+
+impl JointBlock for MMDiTXJointBlock {
+    fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
+        let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
+        let (x_qkv, x_qkv2, x_interm) = self.x_block.pre_attention(x, c)?;
+        let (context_attn, x_attn) =
+            joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
+        let x_attn2 = attn(&x_qkv2, self.num_heads, self.use_flash_attn)?;
+        let context_out =
+            self.context_block
+                .post_attention(&context_attn, context, &context_interm)?;
+        let x_out = self
+            .x_block
+            .post_attention(&x_attn, &x_attn2, x, &x_interm)?;
+        Ok((context_out, x_out))
+    }
+}
+
 pub struct ContextQkvOnlyJointBlock {
     x_block: DiTBlock,
     context_block: QkvOnlyDiTBlock,
@@ -309,26 +468,30 @@ fn joint_attn(
         v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
     };

-    let (batch_size, seqlen, _) = qkv.q.dims3()?;
-    let qkv = Qkv {
-        q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
-        k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
-        v: qkv.v,
-    };
-
-    let headdim = qkv.q.dim(D::Minus1)?;
-    let softmax_scale = 1.0 / (headdim as f64).sqrt();
-    let attn = if use_flash_attn {
-        flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
-    } else {
-        flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
-    };
-
-    let attn = attn.reshape((batch_size, seqlen, ()))?;
+    let seqlen = qkv.q.dim(1)?;
+    let attn = attn(&qkv, num_heads, use_flash_attn)?;
+
     let context_qkv_seqlen = context_qkv.q.dim(1)?;
     let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
     let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;

     Ok((context_attn, x_attn))
 }
+
+fn attn(qkv: &Qkv, num_heads: usize, use_flash_attn: bool) -> Result<Tensor> {
+    let batch_size = qkv.q.dim(0)?;
+    let seqlen = qkv.q.dim(1)?;
+    let qkv = Qkv {
+        q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
+        k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
+        v: qkv.v.clone(),
+    };
+    let headdim = qkv.q.dim(D::Minus1)?;
+    let softmax_scale = 1.0 / (headdim as f64).sqrt();
+    let attn = if use_flash_attn {
+        flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
+    } else {
+        flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
+    };
+    attn.reshape((batch_size, seqlen, ()))
+}


@@ -1,10 +1,15 @@
-// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
+// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206),
+// as well as the MMDiT-X variant introduced for Stable Diffusion 3.5-medium (https://huggingface.co/stabilityai/stable-diffusion-3.5-medium).
 // This follows the implementation of the MMDiT model in the ComfyUI repository.
 // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
+// The MMDiT-X support follows the Stability-AI/sd3.5 repository.
+// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py#L1
 use candle::{Module, Result, Tensor, D};
 use candle_nn as nn;

-use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
+use super::blocks::{
+    ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock,
+};
 use super::embedding::{
     PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
 };
@@ -37,6 +42,20 @@ impl Config {
         }
     }

+    pub fn sd3_5_medium() -> Self {
+        Self {
+            patch_size: 2,
+            in_channels: 16,
+            out_channels: 16,
+            depth: 24,
+            head_size: 64,
+            adm_in_channels: 2048,
+            pos_embed_max_size: 384,
+            context_embed_size: 4096,
+            frequency_embedding_size: 256,
+        }
+    }
+
     pub fn sd3_5_large() -> Self {
         Self {
             patch_size: 2,
@@ -138,7 +157,7 @@ impl MMDiT {
 }

 pub struct MMDiTCore {
-    joint_blocks: Vec<JointBlock>,
+    joint_blocks: Vec<Box<dyn JointBlock>>,
     context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
     final_layer: FinalLayer,
 }
@@ -155,12 +174,24 @@ impl MMDiTCore {
     ) -> Result<Self> {
         let mut joint_blocks = Vec::with_capacity(depth - 1);
         for i in 0..depth - 1 {
-            joint_blocks.push(JointBlock::new(
-                hidden_size,
-                num_heads,
-                use_flash_attn,
-                vb.pp(format!("joint_blocks.{}", i)),
-            )?);
+            let joint_block_vb_pp = format!("joint_blocks.{}", i);
+            // MMDiT-X blocks carry a second self-attention projection (x_block.attn2);
+            // detect its weights to decide which block type to instantiate.
+            let joint_block: Box<dyn JointBlock> =
+                if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) {
+                    Box::new(MMDiTXJointBlock::new(
+                        hidden_size,
+                        num_heads,
+                        use_flash_attn,
+                        vb.pp(&joint_block_vb_pp),
+                    )?)
+                } else {
+                    Box::new(MMDiTJointBlock::new(
+                        hidden_size,
+                        num_heads,
+                        use_flash_attn,
+                        vb.pp(&joint_block_vb_pp),
+                    )?)
+                };
+            joint_blocks.push(joint_block);
         }

         Ok(Self {