mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
DiffNeXt/unet (#859)
* DiffNeXt/unet * Start adding the vae. * VAE residual block. * VAE forward pass. * Add pixel shuffling. * Actually use pixel shuffling.
This commit is contained in:
@ -444,6 +444,18 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let d0 = self.0.to_index(shape, op)?;
|
||||
let d1 = self.1.to_index(shape, op)?;
|
||||
let d2 = self.2.to_index(shape, op)?;
|
||||
let d3 = self.3.to_index(shape, op)?;
|
||||
let d4 = self.4.to_index(shape, op)?;
|
||||
let d5 = self.5.to_index(shape, op)?;
|
||||
Ok(vec![d0, d1, d2, d3, d4, d5])
|
||||
}
|
||||
}
|
||||
|
||||
extract_dims!(dims0, 0, |_: &[usize]| (), ());
|
||||
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
|
@ -189,3 +189,27 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply_op1_no_bwd(&SoftmaxLastDim)
|
||||
}
|
||||
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
|
||||
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
|
||||
let (b_size, c, h, w) = xs.dims4()?;
|
||||
let out_c = c / upscale_factor / upscale_factor;
|
||||
xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
|
||||
.permute((0, 1, 4, 2, 5, 3))?
|
||||
.reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
|
||||
}
|
||||
|
||||
pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
|
||||
let (b_size, c, h, w) = xs.dims4()?;
|
||||
let out_c = c * downscale_factor * downscale_factor;
|
||||
xs.reshape((
|
||||
b_size,
|
||||
c,
|
||||
h / downscale_factor,
|
||||
downscale_factor,
|
||||
w / downscale_factor,
|
||||
downscale_factor,
|
||||
))?
|
||||
.permute((0, 1, 3, 5, 2, 4))?
|
||||
.reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ struct UpBlock {
|
||||
#[derive(Debug)]
|
||||
pub struct WDiffNeXt {
|
||||
clip_mapper: candle_nn::Linear,
|
||||
effnet_mappers: Vec<candle_nn::Conv2d>,
|
||||
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
|
||||
seq_norm: candle_nn::LayerNorm,
|
||||
embedding_conv: candle_nn::Conv2d,
|
||||
embedding_ln: WLayerNorm,
|
||||
@ -84,6 +84,7 @@ pub struct WDiffNeXt {
|
||||
clf_ln: WLayerNorm,
|
||||
clf_conv: candle_nn::Conv2d,
|
||||
c_r: usize,
|
||||
patch_size: usize,
|
||||
}
|
||||
|
||||
impl WDiffNeXt {
|
||||
@ -102,6 +103,7 @@ impl WDiffNeXt {
|
||||
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
|
||||
|
||||
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
|
||||
// TODO: populate effnet_mappers
|
||||
let effnet_mappers = vec![];
|
||||
let cfg = candle_nn::layer_norm::LayerNormConfig {
|
||||
..Default::default()
|
||||
@ -232,6 +234,7 @@ impl WDiffNeXt {
|
||||
clf_ln,
|
||||
clf_conv,
|
||||
c_r,
|
||||
patch_size,
|
||||
})
|
||||
}
|
||||
|
||||
@ -273,14 +276,70 @@ impl WDiffNeXt {
|
||||
};
|
||||
let x_in = xs;
|
||||
|
||||
// TODO: pixel unshuffle.
|
||||
let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?;
|
||||
// TODO: down blocks
|
||||
let level_outputs = xs.clone();
|
||||
// TODO: up blocks
|
||||
let xs = level_outputs;
|
||||
// TODO: pxel shuffle
|
||||
let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(1, 2)?;
|
||||
let mut xs = xs
|
||||
.apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
|
||||
.apply(&self.embedding_conv)?
|
||||
.apply(&self.embedding_ln)?;
|
||||
|
||||
let mut level_outputs = Vec::new();
|
||||
for (i, down_block) in self.down_blocks.iter().enumerate() {
|
||||
if let Some(ln) = &down_block.layer_norm {
|
||||
xs = xs.apply(ln)?
|
||||
}
|
||||
if let Some(conv) = &down_block.conv {
|
||||
xs = xs.apply(conv)?
|
||||
}
|
||||
let skip = match &self.effnet_mappers[i] {
|
||||
None => None,
|
||||
Some(m) => {
|
||||
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
|
||||
Some(m.forward(&effnet)?)
|
||||
}
|
||||
};
|
||||
for block in down_block.sub_blocks.iter() {
|
||||
xs = block.res_block.forward(&xs, skip.as_ref())?;
|
||||
xs = block.ts_block.forward(&xs, &r_embed)?;
|
||||
if let Some(attn_block) = &block.attn_block {
|
||||
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
|
||||
}
|
||||
}
|
||||
level_outputs.push(xs.clone())
|
||||
}
|
||||
level_outputs.reverse();
|
||||
|
||||
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
||||
let skip = match &self.effnet_mappers[self.down_blocks.len() + i] {
|
||||
None => None,
|
||||
Some(m) => {
|
||||
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
|
||||
Some(m.forward(&effnet)?)
|
||||
}
|
||||
};
|
||||
for (j, block) in up_block.sub_blocks.iter().enumerate() {
|
||||
let skip = if j == 0 && i > 0 {
|
||||
Some(&level_outputs[i])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
xs = block.res_block.forward(&xs, skip)?;
|
||||
xs = block.ts_block.forward(&xs, &r_embed)?;
|
||||
if let Some(attn_block) = &block.attn_block {
|
||||
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
|
||||
}
|
||||
}
|
||||
if let Some(ln) = &up_block.layer_norm {
|
||||
xs = xs.apply(ln)?
|
||||
}
|
||||
if let Some(conv) = &up_block.conv {
|
||||
xs = xs.apply(conv)?
|
||||
}
|
||||
}
|
||||
|
||||
let ab = xs
|
||||
.apply(&self.clf_ln)?
|
||||
.apply(&self.clf_conv)?
|
||||
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
|
||||
.chunk(1, 2)?;
|
||||
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
|
||||
(x_in - &ab[0])? / b
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
pub mod common;
|
||||
pub mod diffnext;
|
||||
pub mod paella_vq;
|
||||
pub mod prior;
|
||||
|
111
candle-transformers/src/models/wuerstchen/paella_vq.rs
Normal file
111
candle-transformers/src/models/wuerstchen/paella_vq.rs
Normal file
@ -0,0 +1,111 @@
|
||||
#![allow(unused)]
|
||||
use super::common::{AttnBlock, ResBlock, TimestepBlock};
|
||||
use candle::{DType, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MixingResidualBlock {
|
||||
norm1: candle_nn::LayerNorm,
|
||||
depthwise_conv: candle_nn::Conv2d,
|
||||
norm2: candle_nn::LayerNorm,
|
||||
channelwise_lin1: candle_nn::Linear,
|
||||
channelwise_lin2: candle_nn::Linear,
|
||||
gammas: Vec<f32>,
|
||||
}
|
||||
|
||||
impl MixingResidualBlock {
|
||||
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let cfg = candle_nn::LayerNormConfig {
|
||||
affine: false,
|
||||
eps: 1e-6,
|
||||
remove_mean: true,
|
||||
};
|
||||
let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
|
||||
let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
groups: inp,
|
||||
..Default::default()
|
||||
};
|
||||
let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?;
|
||||
let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?;
|
||||
let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?;
|
||||
let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
depthwise_conv,
|
||||
norm2,
|
||||
channelwise_lin1,
|
||||
channelwise_lin2,
|
||||
gammas,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MixingResidualBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mods = &self.gammas;
|
||||
let x_temp = xs
|
||||
.permute((0, 2, 3, 1))?
|
||||
.apply(&self.norm1)?
|
||||
.permute((0, 3, 1, 2))?
|
||||
.affine(1. + mods[0] as f64, mods[1] as f64)?;
|
||||
// TODO: Add the ReplicationPad2d
|
||||
let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
|
||||
let x_temp = xs
|
||||
.permute((0, 2, 3, 1))?
|
||||
.apply(&self.norm2)?
|
||||
.permute((0, 3, 1, 2))?
|
||||
.affine(1. + mods[3] as f64, mods[4] as f64)?;
|
||||
let x_temp = x_temp
|
||||
.permute((0, 2, 3, 1))?
|
||||
.apply(&self.channelwise_lin1)?
|
||||
.gelu()?
|
||||
.apply(&self.channelwise_lin2)?
|
||||
.permute((0, 3, 1, 2))?;
|
||||
xs + x_temp * mods[5] as f64
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PaellaVQ {
|
||||
in_block_conv: candle_nn::Conv2d,
|
||||
out_block_conv: candle_nn::Conv2d,
|
||||
down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
|
||||
down_blocks_conv: candle_nn::Conv2d,
|
||||
down_blocks_bn: candle_nn::BatchNorm,
|
||||
up_blocks_conv: candle_nn::Conv2d,
|
||||
up_blocks: Vec<(MixingResidualBlock, Option<candle_nn::ConvTranspose2d>)>,
|
||||
}
|
||||
|
||||
impl PaellaVQ {
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
|
||||
for down_block in self.down_blocks.iter() {
|
||||
if let Some(conv) = &down_block.0 {
|
||||
xs = xs.apply(conv)?
|
||||
}
|
||||
xs = xs.apply(&down_block.1)?
|
||||
}
|
||||
xs.apply(&self.down_blocks_conv)?
|
||||
.apply(&self.down_blocks_bn)
|
||||
// TODO: quantizer
|
||||
}
|
||||
|
||||
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.up_blocks_conv)?;
|
||||
for up_block in self.up_blocks.iter() {
|
||||
xs = xs.apply(&up_block.0)?;
|
||||
if let Some(conv) = &up_block.1 {
|
||||
xs = xs.apply(conv)?
|
||||
}
|
||||
}
|
||||
xs.apply(&self.out_block_conv)?
|
||||
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2))
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PaellaVQ {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.decode(&self.encode(xs)?)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user