Simplify handling of flux modulations. (#2394)

This commit is contained in:
Laurent Mazare
2024-08-04 10:09:54 +01:00
committed by GitHub
parent 19db6b9723
commit aa7ac1832d

View File

@ -212,24 +212,82 @@ impl QkNorm {
}
}
#[derive(Debug, Clone)]
pub struct Modulation {
lin: Linear,
multiplier: usize,
struct ModulationOut {
shift: Tensor,
scale: Tensor,
gate: Tensor,
}
impl Modulation {
fn new(dim: usize, double: bool, vb: VarBuilder) -> Result<Self> {
let multiplier = if double { 6 } else { 3 };
let lin = candle_nn::linear(dim, multiplier * dim, vb.pp("lin"))?;
Ok(Self { lin, multiplier })
impl ModulationOut {
fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&(&self.scale + 1.)?)?
.broadcast_add(&self.shift)
}
fn forward(&self, vec_: &Tensor) -> Result<Vec<Tensor>> {
vec_.silu()?
fn gate(&self, xs: &Tensor) -> Result<Tensor> {
self.gate.broadcast_mul(xs)
}
}
#[derive(Debug, Clone)]
struct Modulation1 {
lin: Linear,
}
impl Modulation1 {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let lin = candle_nn::linear(dim, 3 * dim, vb.pp("lin"))?;
Ok(Self { lin })
}
fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
let ys = vec_
.silu()?
.apply(&self.lin)?
.unsqueeze(1)?
.chunk(self.multiplier, D::Minus1)
.chunk(3, D::Minus1)?;
if ys.len() != 3 {
candle::bail!("unexpected len from chunk {ys:?}")
}
Ok(ModulationOut {
shift: ys[0].clone(),
scale: ys[1].clone(),
gate: ys[2].clone(),
})
}
}
#[derive(Debug, Clone)]
struct Modulation2 {
lin: Linear,
}
impl Modulation2 {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let lin = candle_nn::linear(dim, 6 * dim, vb.pp("lin"))?;
Ok(Self { lin })
}
fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
let ys = vec_
.silu()?
.apply(&self.lin)?
.unsqueeze(1)?
.chunk(6, D::Minus1)?;
if ys.len() != 6 {
candle::bail!("unexpected len from chunk {ys:?}")
}
let mod1 = ModulationOut {
shift: ys[0].clone(),
scale: ys[1].clone(),
gate: ys[2].clone(),
};
let mod2 = ModulationOut {
shift: ys[3].clone(),
scale: ys[4].clone(),
gate: ys[5].clone(),
};
Ok((mod1, mod2))
}
}
@ -296,12 +354,12 @@ impl candle::Module for Mlp {
#[derive(Debug, Clone)]
pub struct DoubleStreamBlock {
img_mod: Modulation,
img_mod: Modulation2,
img_norm1: LayerNorm,
img_attn: SelfAttention,
img_norm2: LayerNorm,
img_mlp: Mlp,
txt_mod: Modulation,
txt_mod: Modulation2,
txt_norm1: LayerNorm,
txt_attn: SelfAttention,
txt_norm2: LayerNorm,
@ -312,12 +370,12 @@ impl DoubleStreamBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h_sz = cfg.hidden_size;
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
let img_mod = Modulation::new(h_sz, true, vb.pp("img_mod"))?;
let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
let txt_mod = Modulation::new(h_sz, true, vb.pp("txt_mod"))?;
let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
@ -343,18 +401,14 @@ impl DoubleStreamBlock {
vec_: &Tensor,
pe: &Tensor,
) -> Result<(Tensor, Tensor)> {
let img_mod = self.img_mod.forward(vec_)?; // shift, scale, gate
let txt_mod = self.txt_mod.forward(vec_)?; // shift, scale, gate
let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
let img_modulated = img.apply(&self.img_norm1)?;
let img_modulated = img_modulated
.broadcast_mul(&(&img_mod[1] + 1.)?)?
.broadcast_add(&img_mod[0])?;
let img_modulated = img_mod1.scale_shift(&img_modulated)?;
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
let txt_modulated = txt.apply(&self.txt_norm1)?;
let txt_modulated = txt_modulated
.broadcast_mul(&(&txt_mod[1] + 1.)?)?
.broadcast_add(&txt_mod[0])?;
let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
let q = Tensor::cat(&[txt_q, img_q], 2)?;
@ -365,27 +419,19 @@ impl DoubleStreamBlock {
let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
let img = (img
+ img_attn
.apply(&self.img_attn.proj)?
.broadcast_mul(&img_mod[2]))?;
let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
let img = (&img
+ &img_mod[5].broadcast_mul(
&img.apply(&self.img_norm2)?
.broadcast_mul(&(&img_mod[4] + 1.0)?)?
.broadcast_add(&img_mod[3])?
+ img_mod2.gate(
&img_mod2
.scale_shift(&img.apply(&self.img_norm2)?)?
.apply(&self.img_mlp)?,
)?)?;
let txt = (txt
+ txt_attn
.apply(&self.txt_attn.proj)?
.broadcast_mul(&txt_mod[2]))?;
let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
let txt = (&txt
+ &txt_mod[5].broadcast_mul(
&txt.apply(&self.txt_norm2)?
.broadcast_mul(&(&txt_mod[4] + 1.0)?)?
.broadcast_add(&txt_mod[3])?
+ txt_mod2.gate(
&txt_mod2
.scale_shift(&txt.apply(&self.txt_norm2)?)?
.apply(&self.txt_mlp)?,
)?)?;
@ -399,7 +445,7 @@ pub struct SingleStreamBlock {
linear2: Linear,
norm: QkNorm,
pre_norm: LayerNorm,
modulation: Modulation,
modulation: Modulation1,
h_sz: usize,
mlp_sz: usize,
num_heads: usize,
@ -414,7 +460,7 @@ impl SingleStreamBlock {
let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
let modulation = Modulation::new(h_sz, false, vb.pp("modulation"))?;
let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
Ok(Self {
linear1,
linear2,
@ -429,11 +475,7 @@ impl SingleStreamBlock {
fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
let mod_ = self.modulation.forward(vec_)?;
let (shift, scale, gate) = (&mod_[0], &mod_[1], &mod_[2]);
let x_mod = xs
.apply(&self.pre_norm)?
.broadcast_mul(&(scale + 1.0)?)?
.broadcast_add(shift)?;
let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
let x_mod = x_mod.apply(&self.linear1)?;
let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
let (b, l, _khd) = qkv.dims3()?;
@ -446,7 +488,7 @@ impl SingleStreamBlock {
let k = k.apply(&self.norm.key_norm)?;
let attn = attention(&q, &k, &v, pe)?;
let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
xs + gate.broadcast_mul(&output)
xs + mod_.gate(&output)
}
}