Only use classifier free guidance for the prior. (#896)

* Only use classifier free guidance for the prior.

* Add another specific layer-norm structure.

* Tweaks.

* Fix the latent shape.

* Print the prior shape.

* More shape fixes.

* Remove some debugging `continue` statements.
Laurent Mazare authored 2023-09-19 14:13:05 +01:00 (committed by GitHub)
parent 9cf26c5cff
commit 06e46d7c3b
4 changed files with 125 additions and 74 deletions
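For context on the headline change: with classifier-free guidance, the model is evaluated on both the text-conditioned and the unconditional embedding, and the two noise predictions are blended; this commit applies that blend only in the prior stage. A minimal sketch of the blend, assuming equal-shaped predictions (the helper name and call pattern are illustrative, not part of this commit):

```rust
use candle::{Result, Tensor};

/// Hypothetical helper: classifier-free guidance blend of the prior's
/// unconditional and text-conditional noise predictions.
/// Computes `uncond + scale * (cond - uncond)`; `scale > 1.0` pushes
/// the sample towards the text-conditioned prediction.
fn cfg_blend(uncond: &Tensor, cond: &Tensor, scale: f64) -> Result<Tensor> {
    let scaled = (cond.sub(uncond)? * scale)?;
    uncond.add(&scaled)
}
```

In a sampling loop this amounts to running the prior twice per step (or once on a doubled batch) and blending, while the decoder stage calls its model a single time with no blend.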

@@ -1,4 +1,4 @@
-use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
+use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
 use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
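The import swap above pulls in a layer norm without learnable parameters. A minimal sketch of what such a module can look like, assuming normalization over the last dimension with a fixed epsilon; the actual `LayerNormNoWeights` in `super::common` may differ in detail:

```rust
use candle::{Module, Result, Tensor, D};

/// Illustrative sketch of a weight-free layer norm:
/// (x - mean) / sqrt(var + eps) over the last dimension,
/// with no learned scale or shift.
#[derive(Debug)]
struct LayerNormNoWeightsSketch {
    eps: f64,
}

impl LayerNormNoWeightsSketch {
    fn new(_size: usize) -> Result<Self> {
        // `_size` is kept only for signature parity with the weighted variant.
        Ok(Self { eps: 1e-6 })
    }
}

impl Module for LayerNormNoWeightsSketch {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mean = xs.mean_keepdim(D::Minus1)?;
        let centered = xs.broadcast_sub(&mean)?;
        let var = centered.sqr()?.mean_keepdim(D::Minus1)?;
        centered.broadcast_div(&(var + self.eps)?.sqrt()?)
    }
}
```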
@@ -75,7 +75,7 @@ struct UpBlock {
 pub struct WDiffNeXt {
     clip_mapper: candle_nn::Linear,
     effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
-    seq_norm: WLayerNorm,
+    seq_norm: LayerNormNoWeights,
     embedding_conv: candle_nn::Conv2d,
     embedding_ln: WLayerNorm,
     down_blocks: Vec<DownBlock>,
@@ -133,7 +133,7 @@ impl WDiffNeXt {
             };
             effnet_mappers.push(c)
         }
-        let seq_norm = WLayerNorm::new(c_cond)?;
+        let seq_norm = LayerNormNoWeights::new(c_cond)?;
        let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
         let embedding_conv = candle_nn::conv2d(
             c_in * patch_size * patch_size,
@@ -335,6 +335,7 @@ impl WDiffNeXt {
             level_outputs.push(xs.clone())
         }
         level_outputs.reverse();
+        let mut xs = level_outputs[0].clone();
         for (i, up_block) in self.up_blocks.iter().enumerate() {
             let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
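The last hunk seeds the up path explicitly from the deepest feature map. A stripped-down, hypothetical version of that pattern (conditioning, skip connections and the effnet mappers are omitted) shows why `level_outputs[0]` is the right starting point after the `reverse()`:

```rust
use candle::{Module, Result, Tensor};

/// Hypothetical reduction of the decoder's up path: after `reverse()`,
/// index 0 holds the deepest, lowest-resolution feature map, which
/// becomes the running latent `xs` for the up blocks.
fn up_path(up_blocks: &[impl Module], mut level_outputs: Vec<Tensor>) -> Result<Tensor> {
    level_outputs.reverse();
    let mut xs = level_outputs[0].clone();
    for up_block in up_blocks {
        xs = up_block.forward(&xs)?;
    }
    Ok(xs)
}
```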