diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index e2e0bf81..a20254d9 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -9,3 +9,4 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod t5;
 pub mod whisper;
+pub mod wuerstchen;
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
new file mode 100644
index 00000000..fc731a59
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -0,0 +1,126 @@
+use candle::{Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
+#[derive(Debug)]
+pub struct WLayerNorm {
+    inner: candle_nn::LayerNorm,
+}
+
+impl WLayerNorm {
+    pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
+        let cfg = candle_nn::layer_norm::LayerNormConfig {
+            eps: 1e-6,
+            remove_mean: true,
+            affine: false,
+        };
+        let inner = candle_nn::layer_norm(size, cfg, vb)?;
+        Ok(Self { inner })
+    }
+}
+
+impl Module for WLayerNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.permute((0, 2, 3, 1))?
+            .apply(&self.inner)?
+            .permute((0, 3, 1, 2))
+    }
+}
+
+#[derive(Debug)]
+pub struct TimestepBlock {
+    mapper: candle_nn::Linear,
+}
+
+impl TimestepBlock {
+    pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> {
+        let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?;
+        Ok(Self { mapper })
+    }
+
+    pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> {
+        let ab = self
+            .mapper
+            .forward(t)?
+            .unsqueeze(2)?
+            .unsqueeze(3)?
+            .chunk(2, 1)?;
+        xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1])
+    }
+}
+
+#[derive(Debug)]
+pub struct GlobalResponseNorm {
+    gamma: Tensor,
+    beta: Tensor,
+}
+
+impl GlobalResponseNorm {
+    pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+        let gamma = vb.get((1, 1, 1, 1, dim), "gamma")?;
+        let beta = vb.get((1, 1, 1, 1, dim), "beta")?;
+        Ok(Self { gamma, beta })
+    }
+}
+
+impl Module for GlobalResponseNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
+        let stand_div_norm =
+            agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
+        (xs.broadcast_mul(&stand_div_norm)?
+            .broadcast_mul(&self.gamma)
+            + &self.beta)?
+            + xs
+    }
+}
+
+#[derive(Debug)]
+pub struct ResBlock {
+    depthwise: candle_nn::Conv2d,
+    norm: WLayerNorm,
+    channelwise_lin1: candle_nn::Linear,
+    channelwise_grn: GlobalResponseNorm,
+    channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlock {
+    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+        let cfg = candle_nn::Conv2dConfig {
+            padding: ksize / 2,
+            groups: c,
+            ..Default::default()
+        };
+        let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
+        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
+        let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
+        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+        Ok(Self {
+            depthwise,
+            norm,
+            channelwise_lin1,
+            channelwise_grn,
+            channelwise_lin2,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+        let x_res = xs;
+        let xs = match x_skip {
+            None => xs.clone(),
+            Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?,
+        };
+        let xs = xs
+            .apply(&self.depthwise)?
+            .apply(&self.norm)?
+            .permute((0, 2, 3, 1))?;
+        let xs = xs
+            .apply(&self.channelwise_lin1)?
+            .gelu()?
+            .apply(&self.channelwise_grn)?
+            .apply(&self.channelwise_lin2)?
+            .permute((0, 3, 1, 2))?;
+        xs + x_res
+    }
+}
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
new file mode 100644
index 00000000..82c973a1
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -0,0 +1,55 @@
+#![allow(unused)]
+use super::common::{GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct ResBlockStageB {
+    depthwise: candle_nn::Conv2d,
+    norm: WLayerNorm,
+    channelwise_lin1: candle_nn::Linear,
+    channelwise_grn: GlobalResponseNorm,
+    channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlockStageB {
+    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+        let cfg = candle_nn::Conv2dConfig {
+            groups: c,
+            padding: ksize / 2,
+            ..Default::default()
+        };
+        let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
+        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
+        let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
+        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+        Ok(Self {
+            depthwise,
+            norm,
+            channelwise_lin1,
+            channelwise_grn,
+            channelwise_lin2,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+        let x_res = xs;
+        let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
+        let xs = match x_skip {
+            None => xs.clone(),
+            Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
+        };
+        let xs = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&self.channelwise_lin1)?
+            .gelu()?
+            .apply(&self.channelwise_grn)?
+            .apply(&self.channelwise_lin2)?
+            .permute((0, 3, 1, 2))?;
+        xs + x_res
+    }
+}
+
+#[derive(Debug)]
+pub struct WDiffNeXt {}
diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs
new file mode 100644
index 00000000..81755dd1
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/mod.rs
@@ -0,0 +1,3 @@
+pub mod common;
+pub mod diffnext;
+pub mod prior;
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
new file mode 100644
index 00000000..a4e0300c
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -0,0 +1,94 @@
+#![allow(unused)]
+use super::common::{ResBlock, TimestepBlock};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+struct Block {
+    res_block: ResBlock,
+    ts_block: TimestepBlock,
+    // TODO: attn_block: super::common::AttnBlock,
+}
+
+#[derive(Debug)]
+pub struct WPrior {
+    projection: candle_nn::Conv2d,
+    cond_mapper_lin1: candle_nn::Linear,
+    cond_mapper_lin2: candle_nn::Linear,
+    blocks: Vec<Block>,
+    out_ln: super::common::WLayerNorm,
+    out_conv: candle_nn::Conv2d,
+    c_r: usize,
+}
+
+impl WPrior {
+    pub fn new(
+        c_in: usize,
+        c: usize,
+        c_cond: usize,
+        c_r: usize,
+        depth: usize,
+        _nhead: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
+        let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
+        let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
+        let out_ln = super::common::WLayerNorm::new(c, vb.pp("out.0"))?;
+        let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
+        let mut blocks = Vec::with_capacity(depth);
+        for index in 0..depth {
+            let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
+            let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
+            blocks.push(Block {
+                res_block,
+                ts_block,
+            })
+        }
+        Ok(Self {
+            projection,
+            cond_mapper_lin1,
+            cond_mapper_lin2,
+            blocks,
+            out_ln,
+            out_conv,
+            c_r,
+        })
+    }
+
+    pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+        const MAX_POSITIONS: usize = 10000;
+        let r = (r * MAX_POSITIONS as f64)?;
+        let half_dim = self.c_r / 2;
+        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+            * -emb)?
+            .exp()?;
+        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+        let emb = if self.c_r % 2 == 1 {
+            emb.pad_with_zeros(D::Minus1, 0, 1)?
+        } else {
+            emb
+        };
+        emb.to_dtype(r.dtype())
+    }
+
+    pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
+        let x_in = xs;
+        let mut xs = xs.apply(&self.projection)?;
+        // TODO: leaky relu
+        let c_embed = c
+            .apply(&self.cond_mapper_lin1)?
+            .relu()?
+            .apply(&self.cond_mapper_lin2)?;
+        let r_embed = self.gen_r_embedding(r)?;
+        for block in self.blocks.iter() {
+            xs = block.res_block.forward(&xs, None)?;
+            xs = block.ts_block.forward(&xs, &r_embed)?;
+            // TODO: attn
+        }
+        let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
+        (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
+    }
+}
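
Note for reviewers: a minimal sketch of how the new prior module could be exercised once this lands. Everything below is illustrative rather than part of the patch: the sizes passed to WPrior::new are placeholders, candle_nn::VarBuilder::zeros stands in for weights loaded from a real Würstchen checkpoint, and the candle-core crate is assumed to be imported under the candle alias used throughout this repo. The sketch only builds the module wiring and inspects the sinusoidal embedding produced by gen_r_embedding.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::wuerstchen::prior::WPrior;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Zero-initialized weights are enough to check that the module graph builds;
    // real use would load a converted Würstchen prior checkpoint into the VarBuilder.
    let vb = VarBuilder::zeros(DType::F32, &device);
    // (c_in, c, c_cond, c_r, depth, nhead) -- placeholder sizes for illustration only.
    let prior = WPrior::new(16, 128, 1024, 64, 2, 8, vb)?;
    // One noise level in [0, 1]; gen_r_embedding maps it to a sinusoidal embedding
    // of width c_r, so the shape printed here is (1, 64).
    let r = Tensor::new(&[0.5f32], &device)?;
    let r_embed = prior.gen_r_embedding(&r)?;
    println!("r embedding shape: {:?}", r_embed.shape());
    Ok(())
}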