Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
DiffNeXt/unet (#859)
* DiffNeXt/unet
* Start adding the vae.
* VAE residual block.
* VAE forward pass.
* Add pixel shuffling.
* Actually use pixel shuffling.
@@ -444,6 +444,18 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5)
     }
 }

+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        let d3 = self.3.to_index(shape, op)?;
+        let d4 = self.4.to_index(shape, op)?;
+        let d5 = self.5.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2, d3, d4, d5])
+    }
+}
+
 extract_dims!(dims0, 0, |_: &[usize]| (), ());
 extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
 extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
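This impl extends Dims to six-element tuples; the permute((0, 1, 4, 2, 5, 3)) in the pixel-shuffle hunk below needs it. A minimal sketch of what it unlocks (illustrative only; assumes the candle crate on a CPU device):

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let t = Tensor::zeros((2, 3, 4, 5, 6, 7), DType::F32, &Device::Cpu)?;
        // A six-axis permute, now expressible with a plain tuple.
        let t = t.permute((0, 1, 4, 2, 5, 3))?;
        assert_eq!(t.dims(), &[2, 3, 6, 4, 7, 5]);
        Ok(())
    }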
@@ -189,3 +189,27 @@ impl candle::CustomOp1 for SoftmaxLastDim
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
     xs.apply_op1_no_bwd(&SoftmaxLastDim)
 }
+
+// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
+pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
+    let (b_size, c, h, w) = xs.dims4()?;
+    let out_c = c / upscale_factor / upscale_factor;
+    xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
+        .permute((0, 1, 4, 2, 5, 3))?
+        .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
+}
+
+pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
+    let (b_size, c, h, w) = xs.dims4()?;
+    let out_c = c * downscale_factor * downscale_factor;
+    xs.reshape((
+        b_size,
+        c,
+        h / downscale_factor,
+        downscale_factor,
+        w / downscale_factor,
+        downscale_factor,
+    ))?
+    .permute((0, 1, 3, 5, 2, 4))?
+    .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
+}
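As with torch.nn.PixelShuffle, pixel_shuffle trades channels for spatial resolution, (B, C·r², H, W) → (B, C, H·r, W·r), and pixel_unshuffle inverts it. A round-trip sketch (illustrative; assumes the two functions above, exported from candle_nn::ops):

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let xs = Tensor::zeros((1, 9, 4, 4), DType::F32, &Device::Cpu)?;
        let ys = candle_nn::ops::pixel_shuffle(&xs, 3)?;
        assert_eq!(ys.dims(), &[1, 1, 12, 12]);
        let zs = candle_nn::ops::pixel_unshuffle(&ys, 3)?;
        assert_eq!(zs.dims(), &[1, 9, 4, 4]);
        Ok(())
    }

Neither function checks divisibility up front: a channel count not divisible by r² (or spatial dims not divisible by r) truncates in the integer division and surfaces as a shape error from the subsequent reshape.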
@@ -75,7 +75,7 @@ struct UpBlock
 #[derive(Debug)]
 pub struct WDiffNeXt {
     clip_mapper: candle_nn::Linear,
-    effnet_mappers: Vec<candle_nn::Conv2d>,
+    effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
     seq_norm: candle_nn::LayerNorm,
     embedding_conv: candle_nn::Conv2d,
     embedding_ln: WLayerNorm,
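Wrapping each mapper in Option matches the INJECT_EFFNET table used below: levels flagged false carry None and skip EfficientNet conditioning entirely in the forward pass.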
@@ -84,6 +84,7 @@ pub struct WDiffNeXt {
     clf_ln: WLayerNorm,
     clf_conv: candle_nn::Conv2d,
     c_r: usize,
+    patch_size: usize,
 }

 impl WDiffNeXt {
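The new patch_size field is the factor fed to pixel_unshuffle and pixel_shuffle at the entry and exit of the rewritten forward pass below.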
@@ -102,6 +103,7 @@ impl WDiffNeXt {
         const INJECT_EFFNET: [bool; 4] = [false, true, true, true];

         let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+        // TODO: populate effnet_mappers
         let effnet_mappers = vec![];
         let cfg = candle_nn::layer_norm::LayerNormConfig {
             ..Default::default()
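Note that effnet_mappers is still constructed empty here (hence the TODO); the new forward pass indexes self.effnet_mappers[i] for every down and up block, so it will panic until one entry per block is populated.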
@@ -232,6 +234,7 @@ impl WDiffNeXt {
             clf_ln,
             clf_conv,
             c_r,
+            patch_size,
         })
     }

@@ -273,14 +276,70 @@ impl WDiffNeXt {
         };
         let x_in = xs;

-        // TODO: pixel unshuffle.
-        let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?;
-        // TODO: down blocks
-        let level_outputs = xs.clone();
-        // TODO: up blocks
-        let xs = level_outputs;
-        // TODO: pxel shuffle
-        let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(1, 2)?;
+        let mut xs = xs
+            .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
+            .apply(&self.embedding_conv)?
+            .apply(&self.embedding_ln)?;
+
+        let mut level_outputs = Vec::new();
+        for (i, down_block) in self.down_blocks.iter().enumerate() {
+            if let Some(ln) = &down_block.layer_norm {
+                xs = xs.apply(ln)?
+            }
+            if let Some(conv) = &down_block.conv {
+                xs = xs.apply(conv)?
+            }
+            let skip = match &self.effnet_mappers[i] {
+                None => None,
+                Some(m) => {
+                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+                    Some(m.forward(&effnet)?)
+                }
+            };
+            for block in down_block.sub_blocks.iter() {
+                xs = block.res_block.forward(&xs, skip.as_ref())?;
+                xs = block.ts_block.forward(&xs, &r_embed)?;
+                if let Some(attn_block) = &block.attn_block {
+                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+                }
+            }
+            level_outputs.push(xs.clone())
+        }
+        level_outputs.reverse();
+
+        for (i, up_block) in self.up_blocks.iter().enumerate() {
+            let skip = match &self.effnet_mappers[self.down_blocks.len() + i] {
+                None => None,
+                Some(m) => {
+                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+                    Some(m.forward(&effnet)?)
+                }
+            };
+            for (j, block) in up_block.sub_blocks.iter().enumerate() {
+                let skip = if j == 0 && i > 0 {
+                    Some(&level_outputs[i])
+                } else {
+                    None
+                };
+                xs = block.res_block.forward(&xs, skip)?;
+                xs = block.ts_block.forward(&xs, &r_embed)?;
+                if let Some(attn_block) = &block.attn_block {
+                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+                }
+            }
+            if let Some(ln) = &up_block.layer_norm {
+                xs = xs.apply(ln)?
+            }
+            if let Some(conv) = &up_block.conv {
+                xs = xs.apply(conv)?
+            }
+        }
+
+        let ab = xs
+            .apply(&self.clf_ln)?
+            .apply(&self.clf_conv)?
+            .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
+            .chunk(2, 1)?;
         let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
         (x_in - &ab[0])? / b
     }
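The head ends by predicting two maps: chunk(2, 1) splits the classifier output along the channel axis into a and b, the sigmoid squashes b into (EPS, 1 - EPS), and the block returns (x_in - a) / b, a mean/scale renormalization of the input rather than a direct prediction.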
@@ -1,3 +1,4 @@
 pub mod common;
 pub mod diffnext;
+pub mod paella_vq;
 pub mod prior;
candle-transformers/src/models/wuerstchen/paella_vq.rs (new file, 111 lines)
@@ -0,0 +1,111 @@
+#![allow(unused)]
+use super::common::{AttnBlock, ResBlock, TimestepBlock};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+struct MixingResidualBlock {
+    norm1: candle_nn::LayerNorm,
+    depthwise_conv: candle_nn::Conv2d,
+    norm2: candle_nn::LayerNorm,
+    channelwise_lin1: candle_nn::Linear,
+    channelwise_lin2: candle_nn::Linear,
+    gammas: Vec<f32>,
+}
+
+impl MixingResidualBlock {
+    pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
+        let cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            eps: 1e-6,
+            remove_mean: true,
+        };
+        let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
+        let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm2"))?;
+        let cfg = candle_nn::Conv2dConfig {
+            groups: inp,
+            ..Default::default()
+        };
+        let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?;
+        let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?;
+        let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?;
+        let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?;
+        Ok(Self {
+            norm1,
+            depthwise_conv,
+            norm2,
+            channelwise_lin1,
+            channelwise_lin2,
+            gammas,
+        })
+    }
+}
+
+impl Module for MixingResidualBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mods = &self.gammas;
+        let x_temp = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&self.norm1)?
+            .permute((0, 3, 1, 2))?
+            .affine(1. + mods[0] as f64, mods[1] as f64)?;
+        // TODO: Add the ReplicationPad2d
+        let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
+        let x_temp = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&self.norm2)?
+            .permute((0, 3, 1, 2))?
+            .affine(1. + mods[3] as f64, mods[4] as f64)?;
+        let x_temp = x_temp
+            .permute((0, 2, 3, 1))?
+            .apply(&self.channelwise_lin1)?
+            .gelu()?
+            .apply(&self.channelwise_lin2)?
+            .permute((0, 3, 1, 2))?;
+        xs + x_temp * mods[5] as f64
+    }
+}
+
+#[derive(Debug)]
+struct PaellaVQ {
+    in_block_conv: candle_nn::Conv2d,
+    out_block_conv: candle_nn::Conv2d,
+    down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
+    down_blocks_conv: candle_nn::Conv2d,
+    down_blocks_bn: candle_nn::BatchNorm,
+    up_blocks_conv: candle_nn::Conv2d,
+    up_blocks: Vec<(MixingResidualBlock, Option<candle_nn::ConvTranspose2d>)>,
+}
+
+impl PaellaVQ {
+    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
+        for down_block in self.down_blocks.iter() {
+            if let Some(conv) = &down_block.0 {
+                xs = xs.apply(conv)?
+            }
+            xs = xs.apply(&down_block.1)?
+        }
+        xs.apply(&self.down_blocks_conv)?
+            .apply(&self.down_blocks_bn)
+        // TODO: quantizer
+    }
+
+    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.apply(&self.up_blocks_conv)?;
+        for up_block in self.up_blocks.iter() {
+            xs = xs.apply(&up_block.0)?;
+            if let Some(conv) = &up_block.1 {
+                xs = xs.apply(conv)?
+            }
+        }
+        xs.apply(&self.out_block_conv)?
+            .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2))
+    }
+}
+
+impl Module for PaellaVQ {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.decode(&self.encode(xs)?)
+    }
+}
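Two caveats in this first cut: the depthwise convolution uses the default Conv2dConfig, i.e. no padding, so until the ReplicationPad2d from the TODO is added the 3x3 kernel shrinks H and W by 2 and the residual addition in MixingResidualBlock::forward fails with a shape mismatch; and encode stops short of quantization (the other TODO), so PaellaVQ as a Module is currently a plain encode/decode round trip with no codebook.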