Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support.
* Clippy fixes.
* CFG fix.
* Remove some unnecessary clones.
* Avoid duplicating some of the code.
@@ -36,6 +36,20 @@ impl Config {
             frequency_embedding_size: 256,
         }
     }
+
+    pub fn sd3_5_large() -> Self {
+        Self {
+            patch_size: 2,
+            in_channels: 16,
+            out_channels: 16,
+            depth: 38,
+            head_size: 64,
+            adm_in_channels: 2048,
+            pos_embed_max_size: 192,
+            context_embed_size: 4096,
+            frequency_embedding_size: 256,
+        }
+    }
 }
 
 pub struct MMDiT {
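For scale: in this MMDiT parameterization the transformer width is depth * head_size, so sd3_5_large describes a 38 * 64 = 2432-wide model, versus SD3-Medium's 24 * 64 = 1536. A minimal sketch of selecting the new preset, assuming this Config is the one exposed from candle-transformers/src/models/mmdit/model.rs (the module path is an assumption, not confirmed by this diff):

// Illustrative only; the import path below is assumed.
use candle_transformers::models::mmdit::model::Config;

fn main() {
    // 38 joint-attention blocks of 64-wide heads, i.e. a 2432-dim model.
    let cfg = Config::sd3_5_large();
    // In a real pipeline `cfg` would be handed to the MMDiT constructor
    // together with a VarBuilder over the SD3.5 checkpoint weights.
    let _ = cfg;
}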
@@ -56,6 +56,8 @@ impl QkvOnlyAttnProjections {
 pub struct AttnProjections {
     head_dim: usize,
     qkv: nn::Linear,
+    ln_k: Option<candle_nn::RmsNorm>,
+    ln_q: Option<candle_nn::RmsNorm>,
     proj: nn::Linear,
 }
 
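The two norms are Option-al so that existing SD3-Medium checkpoints, which ship no ln_q/ln_k weights, keep loading unchanged; the next hunk sniffs their presence from the checkpoint at load time. A tiny self-contained demo of that detection pattern, using an in-memory VarMap as a stand-in for a real safetensors file:

use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};

fn main() {
    // An empty VarMap mimics a checkpoint without QK-norm weights.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // Mirrors the check in AttnProjections::new below: only construct the
    // RmsNorm layers when the checkpoint actually contains their weights.
    let has_qk_norm = vb.contains_tensor("ln_k.weight");
    println!("qk-norm weights present: {has_qk_norm}"); // false here
}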
@@ -64,16 +66,42 @@ impl AttnProjections {
         let head_dim = dim / num_heads;
         let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
         let proj = nn::linear(dim, dim, vb.pp("proj"))?;
+        let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
+            let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
+            let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
+            (Some(ln_k), Some(ln_q))
+        } else {
+            (None, None)
+        };
         Ok(Self {
             head_dim,
             qkv,
             proj,
+            ln_k,
+            ln_q,
         })
     }
 
     pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
         let qkv = self.qkv.forward(x)?;
-        split_qkv(&qkv, self.head_dim)
+        let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
+        let q = match self.ln_q.as_ref() {
+            None => q,
+            Some(l) => {
+                let (b, t, h) = q.dims3()?;
+                l.forward(&q.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        let k = match self.ln_k.as_ref() {
+            None => k,
+            Some(l) => {
+                let (b, t, h) = k.dims3()?;
+                l.forward(&k.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        Ok(Qkv { q, k, v })
     }
 
     pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
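What pre_attention now does per head: unflatten the packed (b, t, h) projection into (b, t, heads, head_dim), RMS-normalize over the trailing head_dim (the () in reshape lets candle infer the head count), and flatten back before attention. A standalone sketch of just that reshape-normalize-reshape step, with fresh ones-initialized RmsNorm weights standing in for the checkpoint's ln_q tensor:

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, t, heads, head_dim) = (1, 4, 2, 8);

    // Ones-initialized weights stand in for the checkpoint's ln_q.weight.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;

    // q arrives flattened as (b, t, heads * head_dim); normalize each head
    // independently by splitting out head_dim, then flatten back.
    let q = Tensor::randn(0f32, 1f32, (b, t, heads * head_dim), &dev)?;
    let q = ln_q
        .forward(&q.reshape((b, t, (), head_dim))?)?
        .reshape((b, t, heads * head_dim))?;
    println!("{:?}", q.dims()); // [1, 4, 16]
    Ok(())
}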