diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs
index aa00077e..4e47873f 100644
--- a/candle-transformers/src/models/flux/model.rs
+++ b/candle-transformers/src/models/flux/model.rs
@@ -212,24 +212,82 @@ impl QkNorm {
     }
 }

-#[derive(Debug, Clone)]
-pub struct Modulation {
-    lin: Linear,
-    multiplier: usize,
+struct ModulationOut {
+    shift: Tensor,
+    scale: Tensor,
+    gate: Tensor,
 }

-impl Modulation {
-    fn new(dim: usize, double: bool, vb: VarBuilder) -> Result<Self> {
-        let multiplier = if double { 6 } else { 3 };
-        let lin = candle_nn::linear(dim, multiplier * dim, vb.pp("lin"))?;
-        Ok(Self { lin, multiplier })
+impl ModulationOut {
+    fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.broadcast_mul(&(&self.scale + 1.)?)?
+            .broadcast_add(&self.shift)
     }

-    fn forward(&self, vec_: &Tensor) -> Result<Vec<Tensor>> {
-        vec_.silu()?
+    fn gate(&self, xs: &Tensor) -> Result<Tensor> {
+        self.gate.broadcast_mul(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Modulation1 {
+    lin: Linear,
+}
+
+impl Modulation1 {
+    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+        let lin = candle_nn::linear(dim, 3 * dim, vb.pp("lin"))?;
+        Ok(Self { lin })
+    }
+
+    fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
+        let ys = vec_
+            .silu()?
             .apply(&self.lin)?
             .unsqueeze(1)?
-            .chunk(self.multiplier, D::Minus1)
+            .chunk(3, D::Minus1)?;
+        if ys.len() != 3 {
+            candle::bail!("unexpected len from chunk {ys:?}")
+        }
+        Ok(ModulationOut {
+            shift: ys[0].clone(),
+            scale: ys[1].clone(),
+            gate: ys[2].clone(),
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Modulation2 {
+    lin: Linear,
+}
+
+impl Modulation2 {
+    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+        let lin = candle_nn::linear(dim, 6 * dim, vb.pp("lin"))?;
+        Ok(Self { lin })
+    }
+
+    fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
+        let ys = vec_
+            .silu()?
+            .apply(&self.lin)?
+            .unsqueeze(1)?
+            .chunk(6, D::Minus1)?;
+        if ys.len() != 6 {
+            candle::bail!("unexpected len from chunk {ys:?}")
+        }
+        let mod1 = ModulationOut {
+            shift: ys[0].clone(),
+            scale: ys[1].clone(),
+            gate: ys[2].clone(),
+        };
+        let mod2 = ModulationOut {
+            shift: ys[3].clone(),
+            scale: ys[4].clone(),
+            gate: ys[5].clone(),
+        };
+        Ok((mod1, mod2))
     }
 }

@@ -296,12 +354,12 @@ impl candle::Module for Mlp {

 #[derive(Debug, Clone)]
 pub struct DoubleStreamBlock {
-    img_mod: Modulation,
+    img_mod: Modulation2,
     img_norm1: LayerNorm,
     img_attn: SelfAttention,
     img_norm2: LayerNorm,
     img_mlp: Mlp,
-    txt_mod: Modulation,
+    txt_mod: Modulation2,
     txt_norm1: LayerNorm,
     txt_attn: SelfAttention,
     txt_norm2: LayerNorm,
@@ -312,12 +370,12 @@ impl DoubleStreamBlock {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let h_sz = cfg.hidden_size;
         let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
-        let img_mod = Modulation::new(h_sz, true, vb.pp("img_mod"))?;
+        let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
         let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
         let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
         let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
         let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
-        let txt_mod = Modulation::new(h_sz, true, vb.pp("txt_mod"))?;
+        let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
         let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
         let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
         let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
@@ -343,18 +401,14 @@ impl DoubleStreamBlock {
         vec_: &Tensor,
         pe: &Tensor,
     ) -> Result<(Tensor, Tensor)> {
-        let img_mod = self.img_mod.forward(vec_)?; // shift, scale, gate
-        let txt_mod = self.txt_mod.forward(vec_)?; // shift, scale, gate
+        let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?; // shift, scale, gate
+        let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?; // shift, scale, gate
         let img_modulated = img.apply(&self.img_norm1)?;
-        let img_modulated = img_modulated
-            .broadcast_mul(&(&img_mod[1] + 1.)?)?
-            .broadcast_add(&img_mod[0])?;
+        let img_modulated = img_mod1.scale_shift(&img_modulated)?;
         let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;

         let txt_modulated = txt.apply(&self.txt_norm1)?;
-        let txt_modulated = txt_modulated
-            .broadcast_mul(&(&txt_mod[1] + 1.)?)?
-            .broadcast_add(&txt_mod[0])?;
+        let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
         let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;

         let q = Tensor::cat(&[txt_q, img_q], 2)?;
@@ -365,27 +419,19 @@ impl DoubleStreamBlock {
         let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
         let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;

-        let img = (img
-            + img_attn
-                .apply(&self.img_attn.proj)?
-                .broadcast_mul(&img_mod[2]))?;
+        let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
         let img = (&img
-            + &img_mod[5].broadcast_mul(
-                &img.apply(&self.img_norm2)?
-                    .broadcast_mul(&(&img_mod[4] + 1.0)?)?
-                    .broadcast_add(&img_mod[3])?
+            + img_mod2.gate(
+                &img_mod2
+                    .scale_shift(&img.apply(&self.img_norm2)?)?
                     .apply(&self.img_mlp)?,
             )?)?;

-        let txt = (txt
-            + txt_attn
-                .apply(&self.txt_attn.proj)?
-                .broadcast_mul(&txt_mod[2]))?;
+        let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
         let txt = (&txt
-            + &txt_mod[5].broadcast_mul(
-                &txt.apply(&self.txt_norm2)?
-                    .broadcast_mul(&(&txt_mod[4] + 1.0)?)?
-                    .broadcast_add(&txt_mod[3])?
+            + txt_mod2.gate(
+                &txt_mod2
+                    .scale_shift(&txt.apply(&self.txt_norm2)?)?
                     .apply(&self.txt_mlp)?,
             )?)?;

@@ -399,7 +445,7 @@ pub struct SingleStreamBlock {
     linear2: Linear,
     norm: QkNorm,
     pre_norm: LayerNorm,
-    modulation: Modulation,
+    modulation: Modulation1,
     h_sz: usize,
     mlp_sz: usize,
     num_heads: usize,
@@ -414,7 +460,7 @@ impl SingleStreamBlock {
         let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
         let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
         let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
-        let modulation = Modulation::new(h_sz, false, vb.pp("modulation"))?;
+        let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
         Ok(Self {
             linear1,
             linear2,
@@ -429,11 +475,7 @@

     fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
         let mod_ = self.modulation.forward(vec_)?;
-        let (shift, scale, gate) = (&mod_[0], &mod_[1], &mod_[2]);
-        let x_mod = xs
-            .apply(&self.pre_norm)?
-            .broadcast_mul(&(scale + 1.0)?)?
-            .broadcast_add(shift)?;
+        let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
         let x_mod = x_mod.apply(&self.linear1)?;
         let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
         let (b, l, _khd) = qkv.dims3()?;
@@ -446,7 +488,7 @@ impl SingleStreamBlock {
         let k = k.apply(&self.norm.key_norm)?;
         let attn = attention(&q, &k, &v, pe)?;
         let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
-        xs + gate.broadcast_mul(&output)
+        xs + mod_.gate(&output)
     }
 }
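Note (illustrative, not part of the diff): `ModulationOut::scale_shift` and `ModulationOut::gate` name the two halves of the AdaLN-style modulation that the old code spelled out with indexed chunks (`mod_[0]`..`mod_[5]`). A minimal sketch of the same broadcast arithmetic on plain tensors, with made-up shapes, assuming the `candle` crate as used elsewhere in this file:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // xs: (batch, seq, hidden); shift/scale/gate: (batch, 1, hidden),
        // matching the `unsqueeze(1)` in Modulation{1,2}::forward so the
        // modulation broadcasts over the sequence dimension.
        let xs = Tensor::randn(0f32, 1f32, (2, 4, 8), &dev)?;
        let shift = Tensor::randn(0f32, 1f32, (2, 1, 8), &dev)?;
        let scale = Tensor::randn(0f32, 1f32, (2, 1, 8), &dev)?;
        let gate = Tensor::randn(0f32, 1f32, (2, 1, 8), &dev)?;

        // scale_shift: xs * (scale + 1) + shift
        let modulated = xs.broadcast_mul(&(&scale + 1.)?)?.broadcast_add(&shift)?;
        // gate: element-wise scaling of a branch output before its residual
        // addition, e.g. `img + img_mod1.gate(...)` in DoubleStreamBlock::forward.
        let gated = gate.broadcast_mul(&modulated)?;
        assert_eq!(gated.dims(), &[2, 4, 8]);
        Ok(())
    }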