Remove the parameters for the Wuerstchen layer-norm. (#879)

* Remove the parameters for the Wuerstchen layer-norm.

* Fixes.

* More fixes (including conv-transpose2d.

* More fixes.

* Again more fixes.
This commit is contained in:
Laurent Mazare
2023-09-17 15:59:27 +01:00
committed by GitHub
parent 5f83c13f17
commit 06cc329e71
5 changed files with 45 additions and 45 deletions

View File

@ -302,7 +302,7 @@ pub fn conv_transpose2d_no_bias(
up: bound,
};
let ws = vb.get_with_hints(
(out_channels, in_channels, kernel_size, kernel_size),
(in_channels, out_channels, kernel_size, kernel_size),
"weight",
init,
)?;

View File

@ -1,28 +1,35 @@
use candle::{Module, Result, Tensor, D};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
#[derive(Debug)]
pub struct WLayerNorm {
inner: candle_nn::LayerNorm,
eps: f64,
}
impl WLayerNorm {
pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
let cfg = candle_nn::layer_norm::LayerNormConfig {
eps: 1e-6,
remove_mean: true,
affine: false,
};
let inner = candle_nn::layer_norm(size, cfg, vb)?;
Ok(Self { inner })
pub fn new(_size: usize) -> Result<Self> {
Ok(Self { eps: 1e-6 })
}
}
impl Module for WLayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.permute((0, 2, 3, 1))?
.apply(&self.inner)?
let xs = xs.permute((0, 2, 3, 1))?;
let x_dtype = xs.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = xs.dim(D::Minus1)?;
let xs = xs.to_dtype(internal_dtype)?;
let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let xs = xs.broadcast_sub(&mean_x)?;
let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
.to_dtype(x_dtype)?
.permute((0, 3, 1, 2))
}
}
@ -57,8 +64,8 @@ pub struct GlobalResponseNorm {
impl GlobalResponseNorm {
pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let gamma = vb.get((1, 1, 1, 1, dim), "gamma")?;
let beta = vb.get((1, 1, 1, 1, dim), "beta")?;
let gamma = vb.get((1, 1, 1, dim), "gamma")?;
let beta = vb.get((1, 1, 1, dim), "beta")?;
Ok(Self { gamma, beta })
}
}
@ -92,7 +99,7 @@ impl ResBlock {
..Default::default()
};
let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
let norm = WLayerNorm::new(c, vb.pp("norm"))?;
let norm = WLayerNorm::new(c)?;
let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@ -141,7 +148,7 @@ impl AttnBlock {
self_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let norm = WLayerNorm::new(c, vb.pp("norm"))?;
let norm = WLayerNorm::new(c)?;
let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
Ok(Self {

View File

@ -19,7 +19,7 @@ impl ResBlockStageB {
..Default::default()
};
let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
let norm = WLayerNorm::new(c, vb.pp("norm"))?;
let norm = WLayerNorm::new(c)?;
let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@ -75,7 +75,7 @@ struct UpBlock {
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
seq_norm: candle_nn::LayerNorm,
seq_norm: WLayerNorm,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
down_blocks: Vec<DownBlock>,
@ -98,7 +98,7 @@ impl WDiffNeXt {
) -> Result<Self> {
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
const NHEAD: [usize; 4] = [0, 10, 20, 20];
const NHEAD: [usize; 4] = [1, 10, 20, 20];
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
const EFFNET_EMBD: usize = 16;
@ -133,24 +133,21 @@ impl WDiffNeXt {
};
effnet_mappers.push(c)
}
let cfg = candle_nn::layer_norm::LayerNormConfig {
..Default::default()
};
let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?;
let seq_norm = WLayerNorm::new(c_cond)?;
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
let embedding_conv = candle_nn::conv2d(
c_in * patch_size * patch_size,
C_HIDDEN[1],
C_HIDDEN[0],
1,
Default::default(),
vb.pp("embedding.2"),
vb.pp("embedding.1"),
)?;
let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
let vb = vb.pp("down_blocks").pp(i);
let (layer_norm, conv, start_layer_i) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?;
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
@ -223,7 +220,7 @@ impl WDiffNeXt {
sub_blocks.push(sub_block)
}
let (layer_norm, conv) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
layer_i += 1;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
@ -242,7 +239,7 @@ impl WDiffNeXt {
up_blocks.push(up_block)
}
let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
let clf_conv = candle_nn::conv2d(
C_HIDDEN[0],
2 * c_out * patch_size * patch_size,

View File

@ -1,11 +1,12 @@
use super::common::WLayerNorm;
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
#[derive(Debug)]
pub struct MixingResidualBlock {
norm1: candle_nn::LayerNorm,
norm1: WLayerNorm,
depthwise_conv: candle_nn::Conv2d,
norm2: candle_nn::LayerNorm,
norm2: WLayerNorm,
channelwise_lin1: candle_nn::Linear,
channelwise_lin2: candle_nn::Linear,
gammas: Vec<f32>,
@ -13,13 +14,8 @@ pub struct MixingResidualBlock {
impl MixingResidualBlock {
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
let cfg = candle_nn::LayerNormConfig {
affine: false,
eps: 1e-6,
remove_mean: true,
};
let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
let norm1 = WLayerNorm::new(inp)?;
let norm2 = WLayerNorm::new(inp)?;
let cfg = candle_nn::Conv2dConfig {
groups: inp,
..Default::default()
@ -120,15 +116,15 @@ impl PaellaVQ {
d_idx += 1;
down_blocks.push((conv_block, res_block))
}
let vb_d = vb_d.pp(d_idx);
let down_blocks_conv = candle_nn::conv2d_no_bias(
C_LEVELS[1],
LATENT_CHANNELS,
1,
Default::default(),
vb_d.pp(d_idx),
vb_d.pp(0),
)?;
d_idx += 1;
let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(d_idx))?;
let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?;
let mut up_blocks = Vec::new();
let vb_u = vb.pp("up_blocks");
@ -138,7 +134,7 @@ impl PaellaVQ {
C_LEVELS[1],
1,
Default::default(),
vb_u.pp(u_idx),
vb_u.pp(u_idx).pp(0),
)?;
u_idx += 1;
for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
@ -157,7 +153,7 @@ impl PaellaVQ {
};
let block = candle_nn::conv_transpose2d_no_bias(
c_level,
C_LEVELS[i - 1],
C_LEVELS[C_LEVELS.len() - i - 2],
4,
cfg,
vb_u.pp(u_idx),

View File

@ -33,7 +33,7 @@ impl WPrior {
let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
let out_ln = super::common::WLayerNorm::new(c, vb.pp("out.0"))?;
let out_ln = super::common::WLayerNorm::new(c)?;
let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
let mut blocks = Vec::with_capacity(depth);
for index in 0..depth {