diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index b1f56817..4d500e7f 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -444,6 +444,18 @@ impl Dims for (D1, D2, D3, D4, D5) } } +impl Dims for (D1, D2, D3, D4, D5, D6) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + let d5 = self.5.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4, d5]) + } +} + extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index c4055792..16b2e924 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -189,3 +189,27 @@ impl candle::CustomOp1 for SoftmaxLastDim { pub fn softmax_last_dim(xs: &Tensor) -> Result { xs.apply_op1_no_bwd(&SoftmaxLastDim) } + +// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html +pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result { + let (b_size, c, h, w) = xs.dims4()?; + let out_c = c / upscale_factor / upscale_factor; + xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))? + .permute((0, 1, 4, 2, 5, 3))? + .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor)) +} + +pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result { + let (b_size, c, h, w) = xs.dims4()?; + let out_c = c * downscale_factor * downscale_factor; + xs.reshape(( + b_size, + c, + h / downscale_factor, + downscale_factor, + w / downscale_factor, + downscale_factor, + ))? + .permute((0, 1, 3, 5, 2, 4))? + .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor)) +} diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index 7289a54d..33ca192e 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -75,7 +75,7 @@ struct UpBlock { #[derive(Debug)] pub struct WDiffNeXt { clip_mapper: candle_nn::Linear, - effnet_mappers: Vec, + effnet_mappers: Vec>, seq_norm: candle_nn::LayerNorm, embedding_conv: candle_nn::Conv2d, embedding_ln: WLayerNorm, @@ -84,6 +84,7 @@ pub struct WDiffNeXt { clf_ln: WLayerNorm, clf_conv: candle_nn::Conv2d, c_r: usize, + patch_size: usize, } impl WDiffNeXt { @@ -102,6 +103,7 @@ impl WDiffNeXt { const INJECT_EFFNET: [bool; 4] = [false, true, true, true]; let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?; + // TODO: populate effnet_mappers let effnet_mappers = vec![]; let cfg = candle_nn::layer_norm::LayerNormConfig { ..Default::default() @@ -232,6 +234,7 @@ impl WDiffNeXt { clf_ln, clf_conv, c_r, + patch_size, }) } @@ -273,14 +276,70 @@ impl WDiffNeXt { }; let x_in = xs; - // TODO: pixel unshuffle. - let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?; - // TODO: down blocks - let level_outputs = xs.clone(); - // TODO: up blocks - let xs = level_outputs; - // TODO: pxel shuffle - let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(1, 2)?; + let mut xs = xs + .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))? + .apply(&self.embedding_conv)? + .apply(&self.embedding_ln)?; + + let mut level_outputs = Vec::new(); + for (i, down_block) in self.down_blocks.iter().enumerate() { + if let Some(ln) = &down_block.layer_norm { + xs = xs.apply(ln)? + } + if let Some(conv) = &down_block.conv { + xs = xs.apply(conv)? + } + let skip = match &self.effnet_mappers[i] { + None => None, + Some(m) => { + let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; + Some(m.forward(&effnet)?) + } + }; + for block in down_block.sub_blocks.iter() { + xs = block.res_block.forward(&xs, skip.as_ref())?; + xs = block.ts_block.forward(&xs, &r_embed)?; + if let Some(attn_block) = &block.attn_block { + xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; + } + } + level_outputs.push(xs.clone()) + } + level_outputs.reverse(); + + for (i, up_block) in self.up_blocks.iter().enumerate() { + let skip = match &self.effnet_mappers[self.down_blocks.len() + i] { + None => None, + Some(m) => { + let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; + Some(m.forward(&effnet)?) + } + }; + for (j, block) in up_block.sub_blocks.iter().enumerate() { + let skip = if j == 0 && i > 0 { + Some(&level_outputs[i]) + } else { + None + }; + xs = block.res_block.forward(&xs, skip)?; + xs = block.ts_block.forward(&xs, &r_embed)?; + if let Some(attn_block) = &block.attn_block { + xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; + } + } + if let Some(ln) = &up_block.layer_norm { + xs = xs.apply(ln)? + } + if let Some(conv) = &up_block.conv { + xs = xs.apply(conv)? + } + } + + let ab = xs + .apply(&self.clf_ln)? + .apply(&self.clf_conv)? + .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))? + .chunk(1, 2)?; let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?; (x_in - &ab[0])? / b } diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs index 81755dd1..435bdac2 100644 --- a/candle-transformers/src/models/wuerstchen/mod.rs +++ b/candle-transformers/src/models/wuerstchen/mod.rs @@ -1,3 +1,4 @@ pub mod common; pub mod diffnext; +pub mod paella_vq; pub mod prior; diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs new file mode 100644 index 00000000..6301b7a1 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -0,0 +1,111 @@ +#![allow(unused)] +use super::common::{AttnBlock, ResBlock, TimestepBlock}; +use candle::{DType, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +struct MixingResidualBlock { + norm1: candle_nn::LayerNorm, + depthwise_conv: candle_nn::Conv2d, + norm2: candle_nn::LayerNorm, + channelwise_lin1: candle_nn::Linear, + channelwise_lin2: candle_nn::Linear, + gammas: Vec, +} + +impl MixingResidualBlock { + pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result { + let cfg = candle_nn::LayerNormConfig { + affine: false, + eps: 1e-6, + remove_mean: true, + }; + let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?; + let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?; + let cfg = candle_nn::Conv2dConfig { + groups: inp, + ..Default::default() + }; + let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?; + let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?; + let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?; + let gammas = vb.get(6, "gammas")?.to_vec1::()?; + Ok(Self { + norm1, + depthwise_conv, + norm2, + channelwise_lin1, + channelwise_lin2, + gammas, + }) + } +} + +impl Module for MixingResidualBlock { + fn forward(&self, xs: &Tensor) -> Result { + let mods = &self.gammas; + let x_temp = xs + .permute((0, 2, 3, 1))? + .apply(&self.norm1)? + .permute((0, 3, 1, 2))? + .affine(1. + mods[0] as f64, mods[1] as f64)?; + // TODO: Add the ReplicationPad2d + let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?; + let x_temp = xs + .permute((0, 2, 3, 1))? + .apply(&self.norm2)? + .permute((0, 3, 1, 2))? + .affine(1. + mods[3] as f64, mods[4] as f64)?; + let x_temp = x_temp + .permute((0, 2, 3, 1))? + .apply(&self.channelwise_lin1)? + .gelu()? + .apply(&self.channelwise_lin2)? + .permute((0, 3, 1, 2))?; + xs + x_temp * mods[5] as f64 + } +} + +#[derive(Debug)] +struct PaellaVQ { + in_block_conv: candle_nn::Conv2d, + out_block_conv: candle_nn::Conv2d, + down_blocks: Vec<(Option, MixingResidualBlock)>, + down_blocks_conv: candle_nn::Conv2d, + down_blocks_bn: candle_nn::BatchNorm, + up_blocks_conv: candle_nn::Conv2d, + up_blocks: Vec<(MixingResidualBlock, Option)>, +} + +impl PaellaVQ { + pub fn encode(&self, xs: &Tensor) -> Result { + let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?; + for down_block in self.down_blocks.iter() { + if let Some(conv) = &down_block.0 { + xs = xs.apply(conv)? + } + xs = xs.apply(&down_block.1)? + } + xs.apply(&self.down_blocks_conv)? + .apply(&self.down_blocks_bn) + // TODO: quantizer + } + + pub fn decode(&self, xs: &Tensor) -> Result { + let mut xs = xs.apply(&self.up_blocks_conv)?; + for up_block in self.up_blocks.iter() { + xs = xs.apply(&up_block.0)?; + if let Some(conv) = &up_block.1 { + xs = xs.apply(conv)? + } + } + xs.apply(&self.out_block_conv)? + .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2)) + } +} + +impl Module for PaellaVQ { + fn forward(&self, xs: &Tensor) -> Result { + self.decode(&self.encode(xs)?) + } +}