Stable diffusion 3.5 support. (#2578)

* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
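The CFG fix concerns classifier-free guidance: the model runs on a batch stacking the conditional and unconditional branches, and the two predictions are recombined with the guidance scale. A minimal sketch of the standard recombination, assuming the batch layout is [cond, uncond] along dim 0 (the helper name apply_cfg is hypothetical):

use candle::{Result, Tensor};

// Classifier-free guidance: guided = uncond + scale * (cond - uncond),
// rewritten as scale * cond - (scale - 1) * uncond.
fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
    let cond = noise_pred.narrow(0, 0, 1)?;
    let uncond = noise_pred.narrow(0, 1, 1)?;
    (cond * cfg_scale)? - (uncond * (cfg_scale - 1.0))?
}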
Author: Laurent Mazare (committed by GitHub)
Date: 2024-10-27 10:01:04 +01:00
Parent: 07849aa595
Commit: 37e0ab8c64
5 changed files with 209 additions and 85 deletions


@@ -36,6 +36,20 @@ impl Config {
             frequency_embedding_size: 256,
         }
     }
+
+    pub fn sd3_5_large() -> Self {
+        Self {
+            patch_size: 2,
+            in_channels: 16,
+            out_channels: 16,
+            depth: 38,
+            head_size: 64,
+            adm_in_channels: 2048,
+            pos_embed_max_size: 192,
+            context_embed_size: 4096,
+            frequency_embedding_size: 256,
+        }
+    }
 }
 
 pub struct MMDiT {
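A hypothetical usage sketch (the exact MMDiT::new signature and VarBuilder path are assumptions modeled on the existing sd3_medium path):

// Pick the 3.5-large geometry (depth 38, 16 latent channels) when
// building the MMDiT backbone from a checkpoint's VarBuilder.
let cfg = Config::sd3_5_large();
let mmdit = MMDiT::new(&cfg, use_flash_attn, vb.pp("model.diffusion_model"))?;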


@@ -56,6 +56,8 @@ impl QkvOnlyAttnProjections {
 
 pub struct AttnProjections {
     head_dim: usize,
     qkv: nn::Linear,
+    ln_k: Option<candle_nn::RmsNorm>,
+    ln_q: Option<candle_nn::RmsNorm>,
     proj: nn::Linear,
 }
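The new fields are optional because SD 3.5 checkpoints carry per-head RMSNorm weights for the query and key projections (QK-normalization) while SD 3-medium checkpoints do not; as the next hunk shows, the constructor probes the VarBuilder for the tensors rather than assuming they exist, so older checkpoints keep loading unchanged.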
@@ -64,16 +66,42 @@ impl AttnProjections {
         let head_dim = dim / num_heads;
         let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
         let proj = nn::linear(dim, dim, vb.pp("proj"))?;
+        let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
+            let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
+            let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
+            (Some(ln_k), Some(ln_q))
+        } else {
+            (None, None)
+        };
         Ok(Self {
             head_dim,
             qkv,
             proj,
+            ln_k,
+            ln_q,
         })
     }
 
     pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
         let qkv = self.qkv.forward(x)?;
-        split_qkv(&qkv, self.head_dim)
+        let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
+        let q = match self.ln_q.as_ref() {
+            None => q,
+            Some(l) => {
+                let (b, t, h) = q.dims3()?;
+                l.forward(&q.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        let k = match self.ln_k.as_ref() {
+            None => k,
+            Some(l) => {
+                let (b, t, h) = k.dims3()?;
+                l.forward(&k.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        Ok(Qkv { q, k, v })
     }
 
     pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
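The reshape dance in pre_attention is what makes the norm per-head: RmsNorm normalizes over the last dimension, so (b, t, hidden) is unfolded to (b, t, heads, head_dim) before the call and folded back after. A self-contained sketch of the same trick, assuming the crate is imported as candle (candle-core) and using unit weights in place of checkpoint tensors:

use candle::{DType, Device, Result, Tensor};
use candle_nn::Module;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, t, heads, head_dim) = (1, 4, 2, 8);
    let q = Tensor::randn(0f32, 1f32, (b, t, heads * head_dim), &dev)?;
    // Unit-weight RmsNorm over the trailing head_dim axis.
    let ln_q = candle_nn::RmsNorm::new(Tensor::ones(head_dim, DType::F32, &dev)?, 1e-6);
    // Unfold heads, normalize each head_dim slice independently, refold.
    let q_n = ln_q
        .forward(&q.reshape((b, t, heads, head_dim))?)?
        .reshape((b, t, heads * head_dim))?;
    assert_eq!(q_n.dims(), &[1, 4, 16]);
    Ok(())
}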