Files
candle/candle-transformers/src/models/mmdit/model.rs
Czxck001 dfdce2b602 Add the MMDiT model of Stable Diffusion 3 (#2397)
* add mmdit of stable diffusion 3

lint

add comments

* correct a misplaced comment

* fix cargo fmt

* fix clippy error

* use bail! instead of assert!

* use get_on_dim in splitting qkv
2024-08-05 19:26:15 +02:00

174 lines
5.5 KiB
Rust

// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
// This follows the implementation of the MMDiT model in the ComfyUI repository.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
use super::embedding::{
PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
};
/// Hyper-parameters consumed by [`MMDiT::new`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Patch size handed to the patch embedder, position embedder, final layer and unpatchifier.
    pub patch_size: usize,
    /// Input channel count given to the patch embedder.
    pub in_channels: usize,
    /// Output channel count given to the final layer and the unpatchifier.
    pub out_channels: usize,
    /// Number of joint blocks. `MMDiT::new` also passes this value as the number of attention
    /// heads and multiplies it by `head_size` to obtain the hidden size.
    pub depth: usize,
    /// Width of a single attention head; the hidden size is `head_size * depth`.
    pub head_size: usize,
    /// Input width of the vector embedder (weights under `y_embedder`).
    pub adm_in_channels: usize,
    /// Maximum size passed to the position embedder.
    pub pos_embed_max_size: usize,
    /// Input width of the linear context embedder (weights under `context_embedder`).
    pub context_embed_size: usize,
    /// Frequency embedding size passed to the timestep embedder.
    pub frequency_embedding_size: usize,
}

impl Config {
    /// Returns the preset named after Stable Diffusion 3.
    ///
    /// With these values `MMDiT::new` derives a hidden size of `64 * 24 = 1536`.
    pub fn sd3() -> Self {
        Config {
            // Transformer geometry.
            depth: 24,
            head_size: 64,
            // Patch / channel layout.
            patch_size: 2,
            in_channels: 16,
            out_channels: 16,
            pos_embed_max_size: 192,
            // Conditioning inputs.
            adm_in_channels: 2048,
            context_embed_size: 4096,
            frequency_embedding_size: 256,
        }
    }
}
/// The complete MMDiT model: input embedders, the transformer core and the unpatchifier.
///
/// Instances are built with `MMDiT::new` and evaluated with `MMDiT::forward`.
pub struct MMDiT {
    /// Stack of joint blocks followed by the final layer.
    core: MMDiTCore,
    /// Embeds the spatial input `x`; weights are loaded under `x_embedder`.
    patch_embedder: PatchEmbedder,
    /// Supplies the positional embedding, cropped to the input's height and width.
    pos_embedder: PositionEmbedder,
    /// Embeds the timestep tensor `t`; weights are loaded under `t_embedder`.
    timestep_embedder: TimestepEmbedder,
    /// Embeds the vector conditioning `y`; weights are loaded under `y_embedder`.
    vector_embedder: VectorEmbedder,
    /// Linear projection of the context to the hidden size; weights under `context_embedder`.
    context_embedder: nn::Linear,
    /// Converts the core's output back to a spatial tensor.
    unpatchifier: Unpatchifier,
}
impl MMDiT {
    /// Builds the full model described by `cfg`, loading all weights through `vb`.
    ///
    /// The hidden size is `cfg.head_size * cfg.depth`, and `cfg.depth` is handed to the core
    /// both as the number of blocks and as the number of attention heads. The core and the
    /// position embedder read from the root of `vb`; the other learned sub-modules read from
    /// the `x_embedder`, `t_embedder`, `y_embedder` and `context_embedder` prefixes.
    ///
    /// # Errors
    /// Returns the first error raised while constructing any of the sub-modules.
    pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
        let width = cfg.head_size * cfg.depth;
        // Fields are listed in construction order so the first failing sub-module decides
        // which error is returned.
        Ok(Self {
            core: MMDiTCore::new(
                cfg.depth,
                width,
                cfg.depth,
                cfg.patch_size,
                cfg.out_channels,
                vb.clone(),
            )?,
            patch_embedder: PatchEmbedder::new(
                cfg.patch_size,
                cfg.in_channels,
                width,
                vb.pp("x_embedder"),
            )?,
            pos_embedder: PositionEmbedder::new(
                width,
                cfg.patch_size,
                cfg.pos_embed_max_size,
                vb.clone(),
            )?,
            timestep_embedder: TimestepEmbedder::new(
                width,
                cfg.frequency_embedding_size,
                vb.pp("t_embedder"),
            )?,
            vector_embedder: VectorEmbedder::new(
                cfg.adm_in_channels,
                width,
                vb.pp("y_embedder"),
            )?,
            context_embedder: nn::linear(
                cfg.context_embed_size,
                width,
                vb.pp("context_embedder"),
            )?,
            unpatchifier: Unpatchifier::new(cfg.patch_size, cfg.out_channels)?,
        })
    }

    /// Runs the forward pass of the model.
    ///
    /// Follows the convention of the ComfyUI implementation:
    /// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
    ///
    /// * `x` - `(N, C, H, W)` tensor of spatial inputs (images or latent representations of
    ///   images); `H` and `W` are read from its last two dimensions.
    /// * `t` - `(N,)` tensor of diffusion timesteps.
    /// * `y` - conditioning fed to the vector embedder. The upstream comment describes it as
    ///   an `(N,)` tensor of class labels, but the embedder is built with
    ///   `adm_in_channels` input features, so presumably `(N, adm_in_channels)` - TODO confirm.
    /// * `context` - input of the linear context embedder; presumably its last dimension is
    ///   `context_embed_size` - TODO confirm.
    ///
    /// Returns the unpatchified output restricted to the first `H` entries of dimension 2
    /// and the first `W` entries of dimension 3.
    ///
    /// # Errors
    /// Propagates any error raised by the sub-modules or tensor operations.
    pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
        let (height, width) = (x.dim(D::Minus2)?, x.dim(D::Minus1)?);

        // Patch tokens plus the positional embedding cropped to this input's size.
        let pos_embed = self.pos_embedder.get_cropped_pos_embed(height, width)?;
        let patches = self.patch_embedder.forward(x)?;
        let tokens = patches.broadcast_add(&pos_embed)?;

        // The conditioning is the sum of the timestep and vector embeddings.
        let timestep_emb = self.timestep_embedder.forward(t)?;
        let vector_emb = self.vector_embedder.forward(y)?;
        let conditioning = (timestep_emb + vector_emb)?;

        let context = self.context_embedder.forward(context)?;
        let tokens = self.core.forward(&context, &tokens, &conditioning)?;

        self.unpatchifier
            .unpatchify(&tokens, height, width)?
            .narrow(2, 0, height)?
            .narrow(3, 0, width)
    }
}
/// Transformer core of the MMDiT: the joint blocks followed by the final layer.
pub struct MMDiTCore {
    /// The first `depth - 1` blocks; each one updates both the context and the `x` stream.
    joint_blocks: Vec<JointBlock>,
    /// The last block (weights under `joint_blocks.{depth - 1}`); it yields only the `x` stream.
    context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
    /// Final layer applied to the `x` stream together with the conditioning `c`.
    final_layer: FinalLayer,
}
impl MMDiTCore {
pub fn new(
depth: usize,
hidden_size: usize,
num_heads: usize,
patch_size: usize,
out_channels: usize,
vb: nn::VarBuilder,
) -> Result<Self> {
let mut joint_blocks = Vec::with_capacity(depth - 1);
for i in 0..depth - 1 {
joint_blocks.push(JointBlock::new(
hidden_size,
num_heads,
vb.pp(format!("joint_blocks.{}", i)),
)?);
}
Ok(Self {
joint_blocks,
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
hidden_size,
num_heads,
vb.pp(format!("joint_blocks.{}", depth - 1)),
)?,
final_layer: FinalLayer::new(
hidden_size,
patch_size,
out_channels,
vb.pp("final_layer"),
)?,
})
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
let (mut context, mut x) = (context.clone(), x.clone());
for joint_block in &self.joint_blocks {
(context, x) = joint_block.forward(&context, &x, c)?;
}
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
self.final_layer.forward(&x, c)
}
}