DiffNeXt/unet (#859)

* DiffNeXt/unet

* Start adding the vae.

* VAE residual block.

* VAE forward pass.

* Add pixel shuffling.

* Actually use pixel shuffling.
This commit is contained in:
Laurent Mazare
2023-09-15 11:14:02 +02:00
committed by GitHub
parent 81a36b8713
commit 2746f2c4be
5 changed files with 216 additions and 9 deletions

View File

@ -444,6 +444,18 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5)
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
let d4 = self.4.to_index(shape, op)?;
let d5 = self.5.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3, d4, d5])
}
}
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));

View File

@ -189,3 +189,27 @@ impl candle::CustomOp1 for SoftmaxLastDim {
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
xs.apply_op1_no_bwd(&SoftmaxLastDim)
}
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;
let out_c = c / upscale_factor / upscale_factor;
xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
.permute((0, 1, 4, 2, 5, 3))?
.reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
}
pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;
let out_c = c * downscale_factor * downscale_factor;
xs.reshape((
b_size,
c,
h / downscale_factor,
downscale_factor,
w / downscale_factor,
downscale_factor,
))?
.permute((0, 1, 3, 5, 2, 4))?
.reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}

View File

@ -75,7 +75,7 @@ struct UpBlock {
#[derive(Debug)]
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
effnet_mappers: Vec<candle_nn::Conv2d>,
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
seq_norm: candle_nn::LayerNorm,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
@ -84,6 +84,7 @@ pub struct WDiffNeXt {
clf_ln: WLayerNorm,
clf_conv: candle_nn::Conv2d,
c_r: usize,
patch_size: usize,
}
impl WDiffNeXt {
@ -102,6 +103,7 @@ impl WDiffNeXt {
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
// TODO: populate effnet_mappers
let effnet_mappers = vec![];
let cfg = candle_nn::layer_norm::LayerNormConfig {
..Default::default()
@ -232,6 +234,7 @@ impl WDiffNeXt {
clf_ln,
clf_conv,
c_r,
patch_size,
})
}
@ -273,14 +276,70 @@ impl WDiffNeXt {
};
let x_in = xs;
// TODO: pixel unshuffle.
let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?;
// TODO: down blocks
let level_outputs = xs.clone();
// TODO: up blocks
let xs = level_outputs;
// TODO: pxel shuffle
let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(1, 2)?;
let mut xs = xs
.apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
.apply(&self.embedding_conv)?
.apply(&self.embedding_ln)?;
let mut level_outputs = Vec::new();
for (i, down_block) in self.down_blocks.iter().enumerate() {
if let Some(ln) = &down_block.layer_norm {
xs = xs.apply(ln)?
}
if let Some(conv) = &down_block.conv {
xs = xs.apply(conv)?
}
let skip = match &self.effnet_mappers[i] {
None => None,
Some(m) => {
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
Some(m.forward(&effnet)?)
}
};
for block in down_block.sub_blocks.iter() {
xs = block.res_block.forward(&xs, skip.as_ref())?;
xs = block.ts_block.forward(&xs, &r_embed)?;
if let Some(attn_block) = &block.attn_block {
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
}
}
level_outputs.push(xs.clone())
}
level_outputs.reverse();
for (i, up_block) in self.up_blocks.iter().enumerate() {
let skip = match &self.effnet_mappers[self.down_blocks.len() + i] {
None => None,
Some(m) => {
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
Some(m.forward(&effnet)?)
}
};
for (j, block) in up_block.sub_blocks.iter().enumerate() {
let skip = if j == 0 && i > 0 {
Some(&level_outputs[i])
} else {
None
};
xs = block.res_block.forward(&xs, skip)?;
xs = block.ts_block.forward(&xs, &r_embed)?;
if let Some(attn_block) = &block.attn_block {
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
}
}
if let Some(ln) = &up_block.layer_norm {
xs = xs.apply(ln)?
}
if let Some(conv) = &up_block.conv {
xs = xs.apply(conv)?
}
}
let ab = xs
.apply(&self.clf_ln)?
.apply(&self.clf_conv)?
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
.chunk(1, 2)?;
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
(x_in - &ab[0])? / b
}

View File

@ -1,3 +1,4 @@
pub mod common;
pub mod diffnext;
pub mod paella_vq;
pub mod prior;

View File

@ -0,0 +1,111 @@
#![allow(unused)]
use super::common::{AttnBlock, ResBlock, TimestepBlock};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug)]
struct MixingResidualBlock {
norm1: candle_nn::LayerNorm,
depthwise_conv: candle_nn::Conv2d,
norm2: candle_nn::LayerNorm,
channelwise_lin1: candle_nn::Linear,
channelwise_lin2: candle_nn::Linear,
gammas: Vec<f32>,
}
impl MixingResidualBlock {
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
let cfg = candle_nn::LayerNormConfig {
affine: false,
eps: 1e-6,
remove_mean: true,
};
let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
let cfg = candle_nn::Conv2dConfig {
groups: inp,
..Default::default()
};
let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?;
let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?;
let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?;
let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?;
Ok(Self {
norm1,
depthwise_conv,
norm2,
channelwise_lin1,
channelwise_lin2,
gammas,
})
}
}
impl Module for MixingResidualBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mods = &self.gammas;
let x_temp = xs
.permute((0, 2, 3, 1))?
.apply(&self.norm1)?
.permute((0, 3, 1, 2))?
.affine(1. + mods[0] as f64, mods[1] as f64)?;
// TODO: Add the ReplicationPad2d
let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
let x_temp = xs
.permute((0, 2, 3, 1))?
.apply(&self.norm2)?
.permute((0, 3, 1, 2))?
.affine(1. + mods[3] as f64, mods[4] as f64)?;
let x_temp = x_temp
.permute((0, 2, 3, 1))?
.apply(&self.channelwise_lin1)?
.gelu()?
.apply(&self.channelwise_lin2)?
.permute((0, 3, 1, 2))?;
xs + x_temp * mods[5] as f64
}
}
#[derive(Debug)]
struct PaellaVQ {
in_block_conv: candle_nn::Conv2d,
out_block_conv: candle_nn::Conv2d,
down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
down_blocks_conv: candle_nn::Conv2d,
down_blocks_bn: candle_nn::BatchNorm,
up_blocks_conv: candle_nn::Conv2d,
up_blocks: Vec<(MixingResidualBlock, Option<candle_nn::ConvTranspose2d>)>,
}
impl PaellaVQ {
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
for down_block in self.down_blocks.iter() {
if let Some(conv) = &down_block.0 {
xs = xs.apply(conv)?
}
xs = xs.apply(&down_block.1)?
}
xs.apply(&self.down_blocks_conv)?
.apply(&self.down_blocks_bn)
// TODO: quantizer
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.up_blocks_conv)?;
for up_block in self.up_blocks.iter() {
xs = xs.apply(&up_block.0)?;
if let Some(conv) = &up_block.1 {
xs = xs.apply(conv)?
}
}
xs.apply(&self.out_block_conv)?
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2))
}
}
impl Module for PaellaVQ {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.decode(&self.encode(xs)?)
}
}