diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
index 2b925cee..b3ea91f9 100644
--- a/candle-transformers/src/models/stable_diffusion/attention.rs
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -78,7 +78,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
 }
 
 #[derive(Debug)]
-struct CrossAttention {
+pub struct CrossAttention {
     to_q: nn::Linear,
     to_k: nn::Linear,
     to_v: nn::Linear,
@@ -205,7 +205,7 @@ impl CrossAttention {
         self.reshape_batch_dim_to_heads(&xs)
     }
 
-    fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+    pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
         let context = context.unwrap_or(xs).contiguous()?;
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index fc731a59..10e7b19f 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -124,3 +124,44 @@ impl ResBlock {
         xs + x_res
     }
 }
+use crate::models::stable_diffusion::attention::CrossAttention as Attention;
+#[derive(Debug)]
+pub struct AttnBlock {
+    self_attn: bool,
+    norm: WLayerNorm,
+    attention: Attention,
+    kv_mapper_lin: candle_nn::Linear,
+}
+
+impl AttnBlock {
+    pub fn new(
+        c: usize,
+        c_cond: usize,
+        nhead: usize,
+        self_attn: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
+        let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
+        Ok(Self {
+            self_attn,
+            norm,
+            attention,
+            kv_mapper_lin,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
+        let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
+        let norm_xs = self.norm.forward(xs)?;
+        let kv = if self.self_attn {
+            let (b_size, channel, _, _) = xs.dims4()?;
+            let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
+            Tensor::cat(&[&norm_xs, &kv], 1)?
+        } else {
+            kv
+        };
+        xs + self.attention.forward(&norm_xs, Some(&kv))
+    }
+}
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 82c973a1..8e5099f6 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -52,4 +52,45 @@ impl ResBlockStageB {
 }
 
 #[derive(Debug)]
-pub struct WDiffNeXt {}
+pub struct WDiffNeXt {
+    clip_mapper: candle_nn::Linear,
+    effnet_mappers: Vec<candle_nn::Conv2d>,
+    seq_norm: candle_nn::LayerNorm,
+    embedding_conv: candle_nn::Conv2d,
+    embedding_ln: WLayerNorm,
+}
+
+impl WDiffNeXt {
+    pub fn new(
+        c_in: usize,
+        c_out: usize,
+        vb: VarBuilder,
+        c_cond: usize,
+        clip_embd: usize,
+        patch_size: usize,
+    ) -> Result<Self> {
+        const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+
+        let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+        let effnet_mappers = vec![];
+        let cfg = candle_nn::layer_norm::LayerNormConfig {
+            ..Default::default()
+        };
+        let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
+        let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?;
+        let embedding_conv = candle_nn::conv2d(
+            c_in * patch_size * patch_size,
+            C_HIDDEN[1],
+            1,
+            Default::default(),
+            vb.pp("embedding.2"),
+        )?;
+        Ok(Self {
+            clip_mapper,
+            effnet_mappers,
+            seq_norm,
+            embedding_conv,
+            embedding_ln,
+        })
+    }
+}
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
index a4e0300c..eea70a02 100644
--- a/candle-transformers/src/models/wuerstchen/prior.rs
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -1,5 +1,5 @@
 #![allow(unused)]
-use super::common::{ResBlock, TimestepBlock};
+use super::common::{AttnBlock, ResBlock, TimestepBlock};
 use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
@@ -7,7 +7,7 @@ use candle_nn::VarBuilder;
 struct Block {
     res_block: ResBlock,
     ts_block: TimestepBlock,
-    // TODO: attn_block: super::common::AttnBlock,
+    attn_block: AttnBlock,
 }
 
 #[derive(Debug)]
@@ -28,7 +28,7 @@ impl WPrior {
         c_cond: usize,
         c_r: usize,
         depth: usize,
-        _nhead: usize,
+        nhead: usize,
         vb: VarBuilder,
     ) -> Result<Self> {
         let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
@@ -40,9 +40,17 @@ impl WPrior {
         for index in 0..depth {
             let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
             let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
+            let attn_block = AttnBlock::new(
+                c,
+                c,
+                nhead,
+                true,
+                vb.pp(format!("blocks.{}", 3 * index + 2)),
+            )?;
             blocks.push(Block {
                 res_block,
                 ts_block,
+                attn_block,
             })
         }
         Ok(Self {
@@ -86,7 +94,7 @@ impl WPrior {
         for block in self.blocks.iter() {
             xs = block.res_block.forward(&xs, None)?;
             xs = block.ts_block.forward(&xs, &r_embed)?;
-            // TODO: attn
+            xs = block.attn_block.forward(&xs, &c_embed)?;
         }
         let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(1, 2)?;
         (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
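Note (not part of the patch): when `self_attn` is enabled, `AttnBlock::forward` flattens the normalized `(batch, channels, height, width)` feature map into a `(batch, height * width, channels)` token sequence and prepends it to the mapped conditioning to form the key/value context. Below is a minimal standalone sketch of just that reshape-and-concat step using `candle` tensor ops; the shapes and the sequence length 77 are arbitrary illustration values, not anything fixed by the patch.

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A dummy (batch, channels, height, width) feature map, as produced by the conv blocks.
    let xs = Tensor::zeros((2, 16, 8, 8), DType::F32, &dev)?;
    let (b_size, channels, h, w) = xs.dims4()?;
    // Flatten the spatial dims and move channels last: (b, c, h, w) -> (b, h * w, c).
    // The `()` placeholder lets candle infer the remaining dimension (h * w).
    let tokens = xs.reshape((b_size, channels, ()))?.transpose(1, 2)?;
    assert_eq!(tokens.dims(), &[b_size, h * w, channels]);
    // A dummy conditioning sequence; with `self_attn` set, the block concatenates the
    // image tokens in front of it along the sequence dimension before cross-attention.
    let cond = Tensor::zeros((b_size, 77, channels), DType::F32, &dev)?;
    let kv = Tensor::cat(&[&tokens, &cond], 1)?;
    assert_eq!(kv.dims(), &[b_size, h * w + 77, channels]);
    Ok(())
}
```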