More DiffNeXt. (#847)

* More DiffNeXt.

* Down blocks.
This commit is contained in:
Laurent Mazare
2023-09-14 22:16:31 +02:00
committed by GitHub
parent 0a647875ec
commit 91ec546feb

View File

@ -1,5 +1,5 @@
#![allow(unused)]
use super::common::{GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
use super::common::{AttnBlock, GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
@ -51,6 +51,27 @@ impl ResBlockStageB {
}
}
#[derive(Debug)]
struct SubBlock {
res_block: ResBlockStageB,
ts_block: TimestepBlock,
attn_block: Option<AttnBlock>,
}
#[derive(Debug)]
struct DownBlock {
layer_norm: Option<WLayerNorm>,
conv: Option<candle_nn::Conv2d>,
sub_blocks: Vec<SubBlock>,
}
#[derive(Debug)]
struct UpBlock {
sub_blocks: Vec<SubBlock>,
layer_norm: Option<WLayerNorm>,
conv: Option<candle_nn::Conv2d>,
}
#[derive(Debug)]
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
@ -58,18 +79,27 @@ pub struct WDiffNeXt {
seq_norm: candle_nn::LayerNorm,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
down_blocks: Vec<DownBlock>,
up_blocks: Vec<UpBlock>,
clf_ln: WLayerNorm,
clf_conv: candle_nn::Conv2d,
c_r: usize,
}
impl WDiffNeXt {
pub fn new(
c_in: usize,
c_out: usize,
vb: VarBuilder,
c_r: usize,
c_cond: usize,
clip_embd: usize,
patch_size: usize,
vb: VarBuilder,
) -> Result<Self> {
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
const NHEAD: [usize; 4] = [0, 10, 20, 20];
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
let effnet_mappers = vec![];
@ -85,12 +115,124 @@ impl WDiffNeXt {
Default::default(),
vb.pp("embedding.2"),
)?;
let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
let vb = vb.pp("down_blocks").pp(i);
let (layer_norm, conv, start_layer_i) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
};
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?;
(Some(layer_norm), Some(conv), 2)
} else {
(None, None, 0)
};
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
let mut layer_i = start_layer_i;
for j in 0..BLOCKS[i] {
let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
layer_i += 1;
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
layer_i += 1;
let attn_block = if j == 0 {
None
} else {
let attn_block =
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
layer_i += 1;
Some(attn_block)
};
let sub_block = SubBlock {
res_block,
ts_block,
attn_block,
};
sub_blocks.push(sub_block)
}
let down_block = DownBlock {
layer_norm,
conv,
sub_blocks,
};
down_blocks.push(down_block)
}
// TODO: populate.
let up_blocks = Vec::with_capacity(C_HIDDEN.len());
let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
let clf_conv = candle_nn::conv2d(
C_HIDDEN[0],
2 * c_out * patch_size * patch_size,
1,
Default::default(),
vb.pp("clf.1"),
)?;
Ok(Self {
clip_mapper,
effnet_mappers,
seq_norm,
embedding_conv,
embedding_ln,
down_blocks,
up_blocks,
clf_ln,
clf_conv,
c_r,
})
}
fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
const MAX_POSITIONS: usize = 10000;
let r = (r * MAX_POSITIONS as f64)?;
let half_dim = self.c_r / 2;
let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
* -emb)?
.exp()?;
let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
let emb = if self.c_r % 2 == 1 {
emb.pad_with_zeros(D::Minus1, 0, 1)?
} else {
emb
};
emb.to_dtype(r.dtype())
}
fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
}
pub fn forward(
&self,
xs: &Tensor,
r: &Tensor,
effnet: &Tensor,
clip: Option<&Tensor>,
) -> Result<Tensor> {
const EPS: f64 = 1e-3;
let r_embed = self.gen_r_embedding(r)?;
let clip = match clip {
None => None,
Some(clip) => Some(self.gen_c_embeddings(clip)?),
};
let x_in = xs;
// TODO: pixel unshuffle.
let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?;
// TODO: down blocks
let level_outputs = xs.clone();
// TODO: up blocks
let xs = level_outputs;
// TODO: pxel shuffle
let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(1, 2)?;
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
(x_in - &ab[0])? / b
}
}