Only use classifier free guidance for the prior. (#896)

* Only use classifier free guidance for the prior.

* Add another specific layer-norm structure.

* Tweaks.

* Fix the latent shape.

* Print the prior shape.

* More shape fixes.

* Remove some debugging continue statements.
Laurent Mazare
2023-09-19 14:13:05 +01:00
committed by GitHub
parent 9cf26c5cff
commit 06e46d7c3b
4 changed files with 125 additions and 74 deletions
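
The headline change restricts classifier-free guidance to the prior stage of the Würstchen pipeline; the decoder stage now runs unguided. The guidance step itself mixes a conditional and an unconditional denoiser prediction using a guidance scale. A minimal sketch of that mixing step, as a hypothetical helper rather than code from this commit:

use candle::{Result, Tensor};

// Classifier-free guidance: pred = uncond + scale * (cond - uncond).
// `cond` and `uncond` are the denoiser outputs for the conditional and
// unconditional (empty-prompt) batches; scale > 1 sharpens conditioning.
fn apply_cfg(cond: &Tensor, uncond: &Tensor, scale: f64) -> Result<Tensor> {
    let diff = cond.sub(uncond)?;
    uncond.add(&diff.affine(scale, 0.)?)
}

With scale = 0 this returns the unconditional prediction and with scale = 1 the conditional one; larger scales extrapolate past the conditional prediction.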

@@ -34,6 +34,34 @@ impl Module for WLayerNorm
     }
 }
 
+#[derive(Debug)]
+pub struct LayerNormNoWeights {
+    eps: f64,
+}
+
+impl LayerNormNoWeights {
+    pub fn new(_size: usize) -> Result<Self> {
+        Ok(Self { eps: 1e-6 })
+    }
+}
+
+impl Module for LayerNormNoWeights {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let x_dtype = xs.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let hidden_size = xs.dim(D::Minus1)?;
+        let xs = xs.to_dtype(internal_dtype)?;
+        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        let xs = xs.broadcast_sub(&mean_x)?;
+        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)
+    }
+}
+
 #[derive(Debug)]
 pub struct TimestepBlock {
     mapper: candle_nn::Linear,

@@ -1,4 +1,4 @@
-use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
+use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
 use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
@@ -75,7 +75,7 @@ struct UpBlock
 pub struct WDiffNeXt {
     clip_mapper: candle_nn::Linear,
     effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
-    seq_norm: WLayerNorm,
+    seq_norm: LayerNormNoWeights,
     embedding_conv: candle_nn::Conv2d,
     embedding_ln: WLayerNorm,
     down_blocks: Vec<DownBlock>,
@@ -133,7 +133,7 @@ impl WDiffNeXt
             };
             effnet_mappers.push(c)
         }
-        let seq_norm = WLayerNorm::new(c_cond)?;
+        let seq_norm = LayerNormNoWeights::new(c_cond)?;
         let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
         let embedding_conv = candle_nn::conv2d(
             c_in * patch_size * patch_size,
@@ -335,6 +335,7 @@ impl WDiffNeXt
             level_outputs.push(xs.clone())
         }
         level_outputs.reverse();
+        let mut xs = level_outputs[0].clone();
         for (i, up_block) in self.up_blocks.iter().enumerate() {
             let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
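
For context, the down path pushes one feature map per resolution level and the up path consumes them in reverse, U-Net style. The added line seeds xs with the deepest (lowest-resolution) level output so the first up block receives a latent of the shape it expects; this is the "fix the latent shape" item from the commit message. A schematic of the seeding with plain tensors standing in for the real level outputs, all shapes illustrative:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Pretend down-path outputs, highest resolution first.
    let mut level_outputs = vec![
        Tensor::zeros((1, 320, 32, 32), DType::F32, &dev)?,
        Tensor::zeros((1, 640, 16, 16), DType::F32, &dev)?,
    ];
    level_outputs.reverse();
    // Seed the up path with the deepest feature map.
    let xs = level_outputs[0].clone();
    assert_eq!(xs.dims(), &[1, 640, 16, 16]);
    Ok(())
}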

@@ -1,12 +1,12 @@
-use super::common::WLayerNorm;
+use super::common::LayerNormNoWeights;
 use candle::{Module, Result, Tensor};
 use candle_nn::VarBuilder;
 
 #[derive(Debug)]
 pub struct MixingResidualBlock {
-    norm1: WLayerNorm,
+    norm1: LayerNormNoWeights,
     depthwise_conv: candle_nn::Conv2d,
-    norm2: WLayerNorm,
+    norm2: LayerNormNoWeights,
     channelwise_lin1: candle_nn::Linear,
     channelwise_lin2: candle_nn::Linear,
     gammas: Vec<f32>,
@@ -14,8 +14,8 @@ pub struct MixingResidualBlock
 impl MixingResidualBlock {
     pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
-        let norm1 = WLayerNorm::new(inp)?;
-        let norm2 = WLayerNorm::new(inp)?;
+        let norm1 = LayerNormNoWeights::new(inp)?;
+        let norm2 = LayerNormNoWeights::new(inp)?;
         let cfg = candle_nn::Conv2dConfig {
             groups: inp,
             ..Default::default()
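
As an aside on this last hunk: setting groups equal to the input channel count in Conv2dConfig makes the convolution depthwise, i.e. each channel is convolved with its own kernel and no cross-channel mixing happens; the surrounding channelwise linear layers handle that instead. A sketch of the construction as a hypothetical standalone helper, with the name and kernel size chosen for illustration:

use candle::Result;
use candle_nn::{Conv2d, Conv2dConfig, VarBuilder};

// Depthwise 3x3 convolution: groups == channels gives one kernel per channel.
fn depthwise_conv3x3(channels: usize, vb: VarBuilder) -> Result<Conv2d> {
    let cfg = Conv2dConfig {
        groups: channels,
        ..Default::default()
    };
    candle_nn::conv2d(channels, channels, 3, cfg, vb.pp("depthwise"))
}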