mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the MMDiT model of Stable Diffusion 3 (#2397)
* add mmdit of stable diffusion 3 lint add comments * correct a misplaced comment * fix cargo fmt * fix clippy error * use bail! instead of assert! * use get_on_dim in splitting qkv
This commit is contained in:
294
candle-transformers/src/models/mmdit/blocks.rs
Normal file
294
candle-transformers/src/models/mmdit/blocks.rs
Normal file
@ -0,0 +1,294 @@
|
||||
use candle::{Module, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
use super::projections::{AttnProjections, Mlp, Qkv, QkvOnlyAttnProjections};
|
||||
|
||||
pub struct ModulateIntermediates {
|
||||
gate_msa: Tensor,
|
||||
shift_mlp: Tensor,
|
||||
scale_mlp: Tensor,
|
||||
gate_mlp: Tensor,
|
||||
}
|
||||
|
||||
pub struct DiTBlock {
|
||||
norm1: LayerNormNoAffine,
|
||||
attn: AttnProjections,
|
||||
norm2: LayerNormNoAffine,
|
||||
mlp: Mlp,
|
||||
ada_ln_modulation: nn::Sequential,
|
||||
}
|
||||
|
||||
pub struct LayerNormNoAffine {
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNormNoAffine {
|
||||
pub fn new(eps: f64) -> Self {
|
||||
Self { eps }
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNormNoAffine {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
nn::LayerNorm::new_no_bias(Tensor::ones_like(x)?, self.eps).forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl DiTBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
// {'hidden_size': 1536, 'num_heads': 24}
|
||||
let norm1 = LayerNormNoAffine::new(1e-6);
|
||||
let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
|
||||
let norm2 = LayerNormNoAffine::new(1e-6);
|
||||
let mlp_ratio = 4;
|
||||
let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
|
||||
let n_mods = 6;
|
||||
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
|
||||
hidden_size,
|
||||
n_mods * hidden_size,
|
||||
vb.pp("adaLN_modulation.1"),
|
||||
)?);
|
||||
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
norm2,
|
||||
mlp,
|
||||
ada_ln_modulation,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<(Qkv, ModulateIntermediates)> {
|
||||
let modulation = self.ada_ln_modulation.forward(c)?;
|
||||
let chunks = modulation.chunk(6, D::Minus1)?;
|
||||
let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = (
|
||||
chunks[0].clone(),
|
||||
chunks[1].clone(),
|
||||
chunks[2].clone(),
|
||||
chunks[3].clone(),
|
||||
chunks[4].clone(),
|
||||
chunks[5].clone(),
|
||||
);
|
||||
|
||||
let norm_x = self.norm1.forward(x)?;
|
||||
let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
|
||||
let qkv = self.attn.pre_attention(&modulated_x)?;
|
||||
|
||||
Ok((
|
||||
qkv,
|
||||
ModulateIntermediates {
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
pub fn post_attention(
|
||||
&self,
|
||||
attn: &Tensor,
|
||||
x: &Tensor,
|
||||
mod_interm: &ModulateIntermediates,
|
||||
) -> Result<Tensor> {
|
||||
let attn_out = self.attn.post_attention(attn)?;
|
||||
let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
|
||||
|
||||
let norm_x = self.norm2.forward(&x)?;
|
||||
let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
|
||||
let mlp_out = self.mlp.forward(&modulated_x)?;
|
||||
let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
|
||||
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct QkvOnlyDiTBlock {
|
||||
norm1: LayerNormNoAffine,
|
||||
attn: QkvOnlyAttnProjections,
|
||||
ada_ln_modulation: nn::Sequential,
|
||||
}
|
||||
|
||||
impl QkvOnlyDiTBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let norm1 = LayerNormNoAffine::new(1e-6);
|
||||
let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
|
||||
let n_mods = 2;
|
||||
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
|
||||
hidden_size,
|
||||
n_mods * hidden_size,
|
||||
vb.pp("adaLN_modulation.1"),
|
||||
)?);
|
||||
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
ada_ln_modulation,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<Qkv> {
|
||||
let modulation = self.ada_ln_modulation.forward(c)?;
|
||||
let chunks = modulation.chunk(2, D::Minus1)?;
|
||||
let (shift_msa, scale_msa) = (chunks[0].clone(), chunks[1].clone());
|
||||
|
||||
let norm_x = self.norm1.forward(x)?;
|
||||
let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
|
||||
self.attn.pre_attention(&modulated_x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FinalLayer {
|
||||
norm_final: LayerNormNoAffine,
|
||||
linear: nn::Linear,
|
||||
ada_ln_modulation: nn::Sequential,
|
||||
}
|
||||
|
||||
impl FinalLayer {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
patch_size: usize,
|
||||
out_channels: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let norm_final = LayerNormNoAffine::new(1e-6);
|
||||
let linear = nn::linear(
|
||||
hidden_size,
|
||||
patch_size * patch_size * out_channels,
|
||||
vb.pp("linear"),
|
||||
)?;
|
||||
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
|
||||
hidden_size,
|
||||
2 * hidden_size,
|
||||
vb.pp("adaLN_modulation.1"),
|
||||
)?);
|
||||
|
||||
Ok(Self {
|
||||
norm_final,
|
||||
linear,
|
||||
ada_ln_modulation,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
let modulation = self.ada_ln_modulation.forward(c)?;
|
||||
let chunks = modulation.chunk(2, D::Minus1)?;
|
||||
let (shift, scale) = (chunks[0].clone(), chunks[1].clone());
|
||||
|
||||
let norm_x = self.norm_final.forward(x)?;
|
||||
let modulated_x = modulate(&norm_x, &shift, &scale)?;
|
||||
let output = self.linear.forward(&modulated_x)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
|
||||
let shift = shift.unsqueeze(1)?;
|
||||
let scale = scale.unsqueeze(1)?;
|
||||
let scale_plus_one = scale.add(&Tensor::ones_like(&scale)?)?;
|
||||
shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
|
||||
}
|
||||
|
||||
pub struct JointBlock {
|
||||
x_block: DiTBlock,
|
||||
context_block: DiTBlock,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl JointBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
|
||||
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
|
||||
|
||||
Ok(Self {
|
||||
x_block,
|
||||
context_block,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
|
||||
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
|
||||
let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
|
||||
let context_out =
|
||||
self.context_block
|
||||
.post_attention(&context_attn, context, &context_interm)?;
|
||||
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
|
||||
Ok((context_out, x_out))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ContextQkvOnlyJointBlock {
|
||||
x_block: DiTBlock,
|
||||
context_block: QkvOnlyDiTBlock,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl ContextQkvOnlyJointBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
|
||||
let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
|
||||
Ok(Self {
|
||||
x_block,
|
||||
context_block,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
let context_qkv = self.context_block.pre_attention(context, c)?;
|
||||
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
|
||||
|
||||
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
|
||||
|
||||
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
|
||||
Ok(x_out)
|
||||
}
|
||||
}
|
||||
|
||||
// A QKV-attention that is compatible with the interface of candle_flash_attn::flash_attn
|
||||
// Flash attention regards q, k, v dimensions as (batch_size, seqlen, nheads, headdim)
|
||||
fn flash_compatible_attention(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
) -> Result<Tensor> {
|
||||
let q_dims_for_matmul = q.transpose(1, 2)?.dims().to_vec();
|
||||
let rank = q_dims_for_matmul.len();
|
||||
let q = q.transpose(1, 2)?.flatten_to(rank - 3)?;
|
||||
let k = k.transpose(1, 2)?.flatten_to(rank - 3)?;
|
||||
let v = v.transpose(1, 2)?.flatten_to(rank - 3)?;
|
||||
let attn_weights = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
|
||||
let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
|
||||
attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
|
||||
}
|
||||
|
||||
fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
|
||||
let qkv = Qkv {
|
||||
q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
|
||||
k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
|
||||
v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
|
||||
};
|
||||
|
||||
let (batch_size, seqlen, _) = qkv.q.dims3()?;
|
||||
let qkv = Qkv {
|
||||
q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
|
||||
k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
|
||||
v: qkv.v,
|
||||
};
|
||||
|
||||
let headdim = qkv.q.dim(D::Minus1)?;
|
||||
let softmax_scale = 1.0 / (headdim as f64).sqrt();
|
||||
// let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
|
||||
let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
|
||||
|
||||
let attn = attn.reshape((batch_size, seqlen, ()))?;
|
||||
let context_qkv_seqlen = context_qkv.q.dim(1)?;
|
||||
let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
|
||||
let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
|
||||
|
||||
Ok((context_attn, x_attn))
|
||||
}
|
197
candle-transformers/src/models/mmdit/embedding.rs
Normal file
197
candle-transformers/src/models/mmdit/embedding.rs
Normal file
@ -0,0 +1,197 @@
|
||||
use candle::{bail, DType, Module, Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
pub struct PatchEmbedder {
|
||||
proj: nn::Conv2d,
|
||||
}
|
||||
|
||||
impl PatchEmbedder {
|
||||
pub fn new(
|
||||
patch_size: usize,
|
||||
in_channels: usize,
|
||||
embed_dim: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let proj = nn::conv2d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
patch_size,
|
||||
nn::Conv2dConfig {
|
||||
stride: patch_size,
|
||||
..Default::default()
|
||||
},
|
||||
vb.pp("proj"),
|
||||
)?;
|
||||
|
||||
Ok(Self { proj })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbedder {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.proj.forward(x)?;
|
||||
|
||||
// flatten spatial dim and transpose to channels last
|
||||
let (b, c, h, w) = x.dims4()?;
|
||||
x.reshape((b, c, h * w))?.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Unpatchifier {
|
||||
patch_size: usize,
|
||||
out_channels: usize,
|
||||
}
|
||||
|
||||
impl Unpatchifier {
|
||||
pub fn new(patch_size: usize, out_channels: usize) -> Result<Self> {
|
||||
Ok(Self {
|
||||
patch_size,
|
||||
out_channels,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result<Tensor> {
|
||||
let h = (h + 1) / self.patch_size;
|
||||
let w = (w + 1) / self.patch_size;
|
||||
|
||||
let x = x.reshape((
|
||||
x.dim(0)?,
|
||||
h,
|
||||
w,
|
||||
self.patch_size,
|
||||
self.patch_size,
|
||||
self.out_channels,
|
||||
))?;
|
||||
let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq"
|
||||
x.reshape((
|
||||
x.dim(0)?,
|
||||
self.out_channels,
|
||||
self.patch_size * h,
|
||||
self.patch_size * w,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PositionEmbedder {
|
||||
pos_embed: Tensor,
|
||||
patch_size: usize,
|
||||
pos_embed_max_size: usize,
|
||||
}
|
||||
|
||||
impl PositionEmbedder {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
patch_size: usize,
|
||||
pos_embed_max_size: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let pos_embed = vb.get(
|
||||
(1, pos_embed_max_size * pos_embed_max_size, hidden_size),
|
||||
"pos_embed",
|
||||
)?;
|
||||
Ok(Self {
|
||||
pos_embed,
|
||||
patch_size,
|
||||
pos_embed_max_size,
|
||||
})
|
||||
}
|
||||
pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result<Tensor> {
|
||||
let h = (h + 1) / self.patch_size;
|
||||
let w = (w + 1) / self.patch_size;
|
||||
|
||||
if h > self.pos_embed_max_size || w > self.pos_embed_max_size {
|
||||
bail!("Input size is too large for the position embedding")
|
||||
}
|
||||
|
||||
let top = (self.pos_embed_max_size - h) / 2;
|
||||
let left = (self.pos_embed_max_size - w) / 2;
|
||||
|
||||
let pos_embed =
|
||||
self.pos_embed
|
||||
.reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?;
|
||||
let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?;
|
||||
pos_embed.reshape((1, h * w, ()))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TimestepEmbedder {
|
||||
mlp: nn::Sequential,
|
||||
frequency_embedding_size: usize,
|
||||
}
|
||||
|
||||
impl TimestepEmbedder {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
frequency_embedding_size: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mlp = nn::seq()
|
||||
.add(nn::linear(
|
||||
frequency_embedding_size,
|
||||
hidden_size,
|
||||
vb.pp("mlp.0"),
|
||||
)?)
|
||||
.add(nn::Activation::Silu)
|
||||
.add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
|
||||
|
||||
Ok(Self {
|
||||
mlp,
|
||||
frequency_embedding_size,
|
||||
})
|
||||
}
|
||||
|
||||
fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result<Tensor> {
|
||||
if dim % 2 != 0 {
|
||||
bail!("Embedding dimension must be even")
|
||||
}
|
||||
|
||||
if t.dtype() != DType::F32 && t.dtype() != DType::F64 {
|
||||
bail!("Input tensor must be floating point")
|
||||
}
|
||||
|
||||
let half = dim / 2;
|
||||
let freqs = Tensor::arange(0f32, half as f32, t.device())?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.mul(&Tensor::full(
|
||||
(-f64::ln(max_period) / half as f64) as f32,
|
||||
half,
|
||||
t.device(),
|
||||
)?)?
|
||||
.exp()?;
|
||||
|
||||
let args = t
|
||||
.unsqueeze(1)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.matmul(&freqs.unsqueeze(0)?)?;
|
||||
let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?;
|
||||
embedding.to_dtype(candle::DType::F16)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TimestepEmbedder {
|
||||
fn forward(&self, t: &Tensor) -> Result<Tensor> {
|
||||
let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?;
|
||||
self.mlp.forward(&t_freq)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VectorEmbedder {
|
||||
mlp: nn::Sequential,
|
||||
}
|
||||
|
||||
impl VectorEmbedder {
|
||||
pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let mlp = nn::seq()
|
||||
.add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?)
|
||||
.add(nn::Activation::Silu)
|
||||
.add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
|
||||
|
||||
Ok(Self { mlp })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for VectorEmbedder {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.mlp.forward(x)
|
||||
}
|
||||
}
|
4
candle-transformers/src/models/mmdit/mod.rs
Normal file
4
candle-transformers/src/models/mmdit/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod blocks;
|
||||
pub mod embedding;
|
||||
pub mod model;
|
||||
pub mod projections;
|
173
candle-transformers/src/models/mmdit/model.rs
Normal file
173
candle-transformers/src/models/mmdit/model.rs
Normal file
@ -0,0 +1,173 @@
|
||||
// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
|
||||
// This follows the implementation of the MMDiT model in the ComfyUI repository.
|
||||
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
|
||||
use candle::{Module, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
|
||||
use super::embedding::{
|
||||
PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub patch_size: usize,
|
||||
pub in_channels: usize,
|
||||
pub out_channels: usize,
|
||||
pub depth: usize,
|
||||
pub head_size: usize,
|
||||
pub adm_in_channels: usize,
|
||||
pub pos_embed_max_size: usize,
|
||||
pub context_embed_size: usize,
|
||||
pub frequency_embedding_size: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn sd3() -> Self {
|
||||
Self {
|
||||
patch_size: 2,
|
||||
in_channels: 16,
|
||||
out_channels: 16,
|
||||
depth: 24,
|
||||
head_size: 64,
|
||||
adm_in_channels: 2048,
|
||||
pos_embed_max_size: 192,
|
||||
context_embed_size: 4096,
|
||||
frequency_embedding_size: 256,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MMDiT {
|
||||
core: MMDiTCore,
|
||||
patch_embedder: PatchEmbedder,
|
||||
pos_embedder: PositionEmbedder,
|
||||
timestep_embedder: TimestepEmbedder,
|
||||
vector_embedder: VectorEmbedder,
|
||||
context_embedder: nn::Linear,
|
||||
unpatchifier: Unpatchifier,
|
||||
}
|
||||
|
||||
impl MMDiT {
|
||||
pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let hidden_size = cfg.head_size * cfg.depth;
|
||||
let core = MMDiTCore::new(
|
||||
cfg.depth,
|
||||
hidden_size,
|
||||
cfg.depth,
|
||||
cfg.patch_size,
|
||||
cfg.out_channels,
|
||||
vb.clone(),
|
||||
)?;
|
||||
let patch_embedder = PatchEmbedder::new(
|
||||
cfg.patch_size,
|
||||
cfg.in_channels,
|
||||
hidden_size,
|
||||
vb.pp("x_embedder"),
|
||||
)?;
|
||||
let pos_embedder = PositionEmbedder::new(
|
||||
hidden_size,
|
||||
cfg.patch_size,
|
||||
cfg.pos_embed_max_size,
|
||||
vb.clone(),
|
||||
)?;
|
||||
let timestep_embedder = TimestepEmbedder::new(
|
||||
hidden_size,
|
||||
cfg.frequency_embedding_size,
|
||||
vb.pp("t_embedder"),
|
||||
)?;
|
||||
let vector_embedder =
|
||||
VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?;
|
||||
let context_embedder = nn::linear(
|
||||
cfg.context_embed_size,
|
||||
hidden_size,
|
||||
vb.pp("context_embedder"),
|
||||
)?;
|
||||
let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?;
|
||||
|
||||
Ok(Self {
|
||||
core,
|
||||
patch_embedder,
|
||||
pos_embedder,
|
||||
timestep_embedder,
|
||||
vector_embedder,
|
||||
context_embedder,
|
||||
unpatchifier,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
|
||||
// Following the convention of the ComfyUI implementation.
|
||||
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
|
||||
//
|
||||
// Forward pass of DiT.
|
||||
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
// t: (N,) tensor of diffusion timesteps
|
||||
// y: (N,) tensor of class labels
|
||||
let h = x.dim(D::Minus2)?;
|
||||
let w = x.dim(D::Minus1)?;
|
||||
let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
|
||||
let x = self
|
||||
.patch_embedder
|
||||
.forward(x)?
|
||||
.broadcast_add(&cropped_pos_embed)?;
|
||||
let c = self.timestep_embedder.forward(t)?;
|
||||
let y = self.vector_embedder.forward(y)?;
|
||||
let c = (c + y)?;
|
||||
let context = self.context_embedder.forward(context)?;
|
||||
|
||||
let x = self.core.forward(&context, &x, &c)?;
|
||||
let x = self.unpatchifier.unpatchify(&x, h, w)?;
|
||||
x.narrow(2, 0, h)?.narrow(3, 0, w)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MMDiTCore {
|
||||
joint_blocks: Vec<JointBlock>,
|
||||
context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
|
||||
final_layer: FinalLayer,
|
||||
}
|
||||
|
||||
impl MMDiTCore {
|
||||
pub fn new(
|
||||
depth: usize,
|
||||
hidden_size: usize,
|
||||
num_heads: usize,
|
||||
patch_size: usize,
|
||||
out_channels: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mut joint_blocks = Vec::with_capacity(depth - 1);
|
||||
for i in 0..depth - 1 {
|
||||
joint_blocks.push(JointBlock::new(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
vb.pp(format!("joint_blocks.{}", i)),
|
||||
)?);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
joint_blocks,
|
||||
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
vb.pp(format!("joint_blocks.{}", depth - 1)),
|
||||
)?,
|
||||
final_layer: FinalLayer::new(
|
||||
hidden_size,
|
||||
patch_size,
|
||||
out_channels,
|
||||
vb.pp("final_layer"),
|
||||
)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
let (mut context, mut x) = (context.clone(), x.clone());
|
||||
for joint_block in &self.joint_blocks {
|
||||
(context, x) = joint_block.forward(&context, &x, c)?;
|
||||
}
|
||||
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
|
||||
self.final_layer.forward(&x, c)
|
||||
}
|
||||
}
|
94
candle-transformers/src/models/mmdit/projections.rs
Normal file
94
candle-transformers/src/models/mmdit/projections.rs
Normal file
@ -0,0 +1,94 @@
|
||||
use candle::{Module, Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
pub struct Qkv {
|
||||
pub q: Tensor,
|
||||
pub k: Tensor,
|
||||
pub v: Tensor,
|
||||
}
|
||||
|
||||
pub struct Mlp {
|
||||
fc1: nn::Linear,
|
||||
act: nn::Activation,
|
||||
fc2: nn::Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
pub fn new(
|
||||
in_features: usize,
|
||||
hidden_features: usize,
|
||||
vb: candle_nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let fc1 = nn::linear(in_features, hidden_features, vb.pp("fc1"))?;
|
||||
let act = nn::Activation::GeluPytorchTanh;
|
||||
let fc2 = nn::linear(hidden_features, in_features, vb.pp("fc2"))?;
|
||||
|
||||
Ok(Self { fc1, act, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.fc1.forward(x)?;
|
||||
let x = self.act.forward(&x)?;
|
||||
self.fc2.forward(&x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct QkvOnlyAttnProjections {
|
||||
qkv: nn::Linear,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
impl QkvOnlyAttnProjections {
|
||||
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
// {'dim': 1536, 'num_heads': 24}
|
||||
let head_dim = dim / num_heads;
|
||||
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
|
||||
Ok(Self { qkv, head_dim })
|
||||
}
|
||||
|
||||
pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
|
||||
let qkv = self.qkv.forward(x)?;
|
||||
split_qkv(&qkv, self.head_dim)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AttnProjections {
|
||||
head_dim: usize,
|
||||
qkv: nn::Linear,
|
||||
proj: nn::Linear,
|
||||
}
|
||||
|
||||
impl AttnProjections {
|
||||
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let head_dim = dim / num_heads;
|
||||
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
|
||||
let proj = nn::linear(dim, dim, vb.pp("proj"))?;
|
||||
Ok(Self {
|
||||
head_dim,
|
||||
qkv,
|
||||
proj,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
|
||||
let qkv = self.qkv.forward(x)?;
|
||||
split_qkv(&qkv, self.head_dim)
|
||||
}
|
||||
|
||||
pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.proj.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result<Qkv> {
|
||||
let (batch_size, seq_len, _) = qkv.dims3()?;
|
||||
let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?;
|
||||
let q = qkv.get_on_dim(2, 0)?;
|
||||
let q = q.reshape((batch_size, seq_len, ()))?;
|
||||
let k = qkv.get_on_dim(2, 1)?;
|
||||
let k = k.reshape((batch_size, seq_len, ()))?;
|
||||
let v = qkv.get_on_dim(2, 2)?;
|
||||
Ok(Qkv { q, k, v })
|
||||
}
|
@ -32,6 +32,7 @@ pub mod metavoice;
|
||||
pub mod mistral;
|
||||
pub mod mixformer;
|
||||
pub mod mixtral;
|
||||
pub mod mmdit;
|
||||
pub mod mobilenetv4;
|
||||
pub mod mobileone;
|
||||
pub mod moondream;
|
||||
|
Reference in New Issue
Block a user