Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.
* Fixes.
* More fixes (including conv-transpose2d).
* More fixes.
* Again more fixes.
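In short: the Wuerstchen layer norm carries no learned scale or bias, so WLayerNorm no longer takes a VarBuilder; it stores only an epsilon and computes the normalization directly in forward (see the WLayerNorm hunk below), and every call site drops its vb.pp(...) argument. The following is a rough standalone sketch of that computation, assuming the candle crate, with made-up tensor sizes and omitting the half-precision upcast the committed code performs; it is illustrative, not the committed code.

use candle::{Device, Result, Tensor, D};

// Parameter-free layer norm over the last (channel) dimension, as WLayerNorm
// now computes it; names and shapes here are illustrative.
fn channel_layer_norm(xs: &Tensor, eps: f64) -> Result<Tensor> {
    let hidden_size = xs.dim(D::Minus1)?;
    let mean = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let xs = xs.broadcast_sub(&mean)?;
    let var = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    xs.broadcast_div(&(var + eps)?.sqrt()?)
}

fn main() -> Result<()> {
    // Channels-last input; in the model the NCHW tensor is permuted first.
    let xs = Tensor::randn(0f32, 1f32, (1, 8, 8, 32), &Device::Cpu)?;
    let ys = channel_layer_norm(&xs, 1e-6)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}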
@@ -302,7 +302,7 @@ pub fn conv_transpose2d_no_bias(
         up: bound,
     };
     let ws = vb.get_with_hints(
-        (out_channels, in_channels, kernel_size, kernel_size),
+        (in_channels, out_channels, kernel_size, kernel_size),
         "weight",
         init,
     )?;

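The swapped shape above matches how transposed-convolution weights are laid out in PyTorch checkpoints: ConvTranspose2d stores its kernel as (in_channels, out_channels, k, k), the reverse of a plain Conv2d, so the weight has to be read with in_channels first. A tiny illustrative check with made-up sizes (not from the commit):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Hypothetical sizes; the point is only the (in_channels, out_channels, k, k) order.
    let (in_c, out_c, k) = (64usize, 32, 4);
    let ws = Tensor::zeros((in_c, out_c, k, k), DType::F32, &Device::Cpu)?;
    assert_eq!(ws.dims(), &[in_c, out_c, k, k]);
    Ok(())
}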
@@ -1,28 +1,35 @@
-use candle::{Module, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
 // https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
 #[derive(Debug)]
 pub struct WLayerNorm {
-    inner: candle_nn::LayerNorm,
+    eps: f64,
 }
 
 impl WLayerNorm {
-    pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
-        let cfg = candle_nn::layer_norm::LayerNormConfig {
-            eps: 1e-6,
-            remove_mean: true,
-            affine: false,
-        };
-        let inner = candle_nn::layer_norm(size, cfg, vb)?;
-        Ok(Self { inner })
+    pub fn new(_size: usize) -> Result<Self> {
+        Ok(Self { eps: 1e-6 })
     }
 }
 
 impl Module for WLayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.permute((0, 2, 3, 1))?
-            .apply(&self.inner)?
+        let xs = xs.permute((0, 2, 3, 1))?;
+
+        let x_dtype = xs.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+
+        let hidden_size = xs.dim(D::Minus1)?;
+        let xs = xs.to_dtype(internal_dtype)?;
+        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        let xs = xs.broadcast_sub(&mean_x)?;
+        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)?
             .permute((0, 3, 1, 2))
     }
 }

@@ -57,8 +64,8 @@ pub struct GlobalResponseNorm {
 
 impl GlobalResponseNorm {
     pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
-        let gamma = vb.get((1, 1, 1, 1, dim), "gamma")?;
-        let beta = vb.get((1, 1, 1, 1, dim), "beta")?;
+        let gamma = vb.get((1, 1, 1, dim), "gamma")?;
+        let beta = vb.get((1, 1, 1, dim), "beta")?;
         Ok(Self { gamma, beta })
     }
 }

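Dropping gamma and beta from rank 5 to rank 4 lets them broadcast against channels-last activations of shape (batch, height, width, dim), which appears to be how the diffusers reference applies the global response norm. An illustrative broadcast check with made-up sizes:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dim = 16usize;
    // A (1, 1, 1, dim) parameter lines up with a (b, h, w, dim) activation.
    let gamma = Tensor::ones((1, 1, 1, dim), DType::F32, &Device::Cpu)?;
    let xs = Tensor::zeros((2, 8, 8, dim), DType::F32, &Device::Cpu)?;
    let ys = xs.broadcast_mul(&gamma)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}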
@@ -92,7 +99,7 @@ impl ResBlock {
             ..Default::default()
         };
         let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
-        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let norm = WLayerNorm::new(c)?;
         let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
         let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
         let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;

@@ -141,7 +148,7 @@ impl AttnBlock {
         self_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let norm = WLayerNorm::new(c)?;
         let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
         let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
         Ok(Self {

@@ -19,7 +19,7 @@ impl ResBlockStageB {
             ..Default::default()
         };
         let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
-        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let norm = WLayerNorm::new(c)?;
         let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
         let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
         let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;

@@ -75,7 +75,7 @@ struct UpBlock {
 pub struct WDiffNeXt {
     clip_mapper: candle_nn::Linear,
     effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
-    seq_norm: candle_nn::LayerNorm,
+    seq_norm: WLayerNorm,
     embedding_conv: candle_nn::Conv2d,
     embedding_ln: WLayerNorm,
     down_blocks: Vec<DownBlock>,

@@ -98,7 +98,7 @@ impl WDiffNeXt {
     ) -> Result<Self> {
         const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
         const BLOCKS: [usize; 4] = [4, 4, 14, 4];
-        const NHEAD: [usize; 4] = [0, 10, 20, 20];
+        const NHEAD: [usize; 4] = [1, 10, 20, 20];
         const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
         const EFFNET_EMBD: usize = 16;
 

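Bumping NHEAD[0] from 0 to 1 avoids a zero divisor: the attention block derives a per-head size as c / nhead (see the AttnBlock hunk above), so 0 is never a safe value; 1 works as a placeholder whether or not the first stage actually builds attention blocks. A trivial illustration:

// Illustrative only: a per-head dimension is well defined only for nhead >= 1;
// with the old NHEAD[0] = 0 the division would panic.
fn head_dim(c: usize, nhead: usize) -> usize {
    assert!(nhead > 0, "nhead must be at least 1");
    c / nhead
}

fn main() {
    assert_eq!(head_dim(320, 1), 320);
    assert_eq!(head_dim(640, 10), 64);
}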
@@ -133,24 +133,21 @@ impl WDiffNeXt {
             };
             effnet_mappers.push(c)
         }
-        let cfg = candle_nn::layer_norm::LayerNormConfig {
-            ..Default::default()
-        };
-        let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
-        let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?;
+        let seq_norm = WLayerNorm::new(c_cond)?;
+        let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
         let embedding_conv = candle_nn::conv2d(
             c_in * patch_size * patch_size,
-            C_HIDDEN[1],
+            C_HIDDEN[0],
             1,
             Default::default(),
-            vb.pp("embedding.2"),
+            vb.pp("embedding.1"),
         )?;
 
         let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
         for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
             let vb = vb.pp("down_blocks").pp(i);
             let (layer_norm, conv, start_layer_i) = if i > 0 {
-                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?;
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                 let cfg = candle_nn::Conv2dConfig {
                     stride: 2,
                     ..Default::default()

@@ -223,7 +220,7 @@ impl WDiffNeXt {
                 sub_blocks.push(sub_block)
             }
             let (layer_norm, conv) = if i > 0 {
-                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                 layer_i += 1;
                 let cfg = candle_nn::Conv2dConfig {
                     stride: 2,

@@ -242,7 +239,7 @@ impl WDiffNeXt {
             up_blocks.push(up_block)
         }
 
-        let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
+        let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
         let clf_conv = candle_nn::conv2d(
             C_HIDDEN[0],
             2 * c_out * patch_size * patch_size,

@@ -1,11 +1,12 @@
+use super::common::WLayerNorm;
 use candle::{Module, Result, Tensor};
 use candle_nn::VarBuilder;
 
 #[derive(Debug)]
 pub struct MixingResidualBlock {
-    norm1: candle_nn::LayerNorm,
+    norm1: WLayerNorm,
     depthwise_conv: candle_nn::Conv2d,
-    norm2: candle_nn::LayerNorm,
+    norm2: WLayerNorm,
     channelwise_lin1: candle_nn::Linear,
     channelwise_lin2: candle_nn::Linear,
     gammas: Vec<f32>,

@@ -13,13 +14,8 @@ pub struct MixingResidualBlock {
 
 impl MixingResidualBlock {
     pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
-        let cfg = candle_nn::LayerNormConfig {
-            affine: false,
-            eps: 1e-6,
-            remove_mean: true,
-        };
-        let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
-        let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?;
+        let norm1 = WLayerNorm::new(inp)?;
+        let norm2 = WLayerNorm::new(inp)?;
         let cfg = candle_nn::Conv2dConfig {
             groups: inp,
             ..Default::default()

@@ -120,15 +116,15 @@ impl PaellaVQ {
             d_idx += 1;
             down_blocks.push((conv_block, res_block))
         }
+        let vb_d = vb_d.pp(d_idx);
         let down_blocks_conv = candle_nn::conv2d_no_bias(
             C_LEVELS[1],
             LATENT_CHANNELS,
             1,
             Default::default(),
-            vb_d.pp(d_idx),
+            vb_d.pp(0),
         )?;
-        d_idx += 1;
-        let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(d_idx))?;
+        let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?;
 
         let mut up_blocks = Vec::new();
         let vb_u = vb.pp("up_blocks");

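This PaellaVQ change re-scopes the VarBuilder: the trailing conv and batch-norm now share a single vb_d.pp(d_idx) prefix with sub-indices 0 and 1, instead of consuming two consecutive d_idx slots. VarBuilder::pp simply composes dotted name prefixes, so the weights presumably resolve to names like down_blocks.<d_idx>.0.weight. A small sketch with a zeros-backed builder and hypothetical names:

use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    // Every lookup on a zeros-backed builder succeeds and returns zeros, which
    // makes it handy for checking how prefixes compose.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let vb_d = vb.pp("down_blocks").pp(3usize);
    let w = vb_d.pp(0usize).get((4, 4), "weight")?; // ~ "down_blocks.3.0.weight"
    assert_eq!(w.dims(), &[4, 4]);
    Ok(())
}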
@@ -138,7 +134,7 @@ impl PaellaVQ {
             C_LEVELS[1],
             1,
             Default::default(),
-            vb_u.pp(u_idx),
+            vb_u.pp(u_idx).pp(0),
         )?;
         u_idx += 1;
         for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {

@@ -157,7 +153,7 @@ impl PaellaVQ {
             };
             let block = candle_nn::conv_transpose2d_no_bias(
                 c_level,
-                C_LEVELS[i - 1],
+                C_LEVELS[C_LEVELS.len() - i - 2],
                 4,
                 cfg,
                 vb_u.pp(u_idx),

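The up-block fix corrects the index arithmetic: the loop walks C_LEVELS in reverse, so the next (smaller) level for step i is C_LEVELS[C_LEVELS.len() - i - 2], whereas the old C_LEVELS[i - 1] indexes from the wrong end and underflows at i == 0. A toy check with illustrative values:

fn main() {
    // Walking the levels from the top, each step should upsample to the next
    // smaller level; the reversed-index formula picks exactly that.
    let c_levels = [32usize, 64, 128];
    let next: Vec<usize> = c_levels
        .iter()
        .rev()
        .enumerate()
        .take(c_levels.len() - 1)
        .map(|(i, _)| c_levels[c_levels.len() - i - 2])
        .collect();
    assert_eq!(next, vec![64, 32]); // 128 -> 64, 64 -> 32
}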
@@ -33,7 +33,7 @@ impl WPrior {
         let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
         let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
         let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
-        let out_ln = super::common::WLayerNorm::new(c, vb.pp("out.0"))?;
+        let out_ln = super::common::WLayerNorm::new(c)?;
         let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
         let mut blocks = Vec::with_capacity(depth);
         for index in 0..depth {