Remove the parameters for the Wuerstchen layer-norm. (#879)

* Remove the parameters for the Wuerstchen layer-norm.

* Fixes.

* More fixes (including conv-transpose2d.

* More fixes.

* Again more fixes.
This commit is contained in:
Laurent Mazare
2023-09-17 15:59:27 +01:00
committed by GitHub
parent 5f83c13f17
commit 06cc329e71
5 changed files with 45 additions and 45 deletions

View File

@ -1,28 +1,35 @@
use candle::{Module, Result, Tensor, D};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
#[derive(Debug)]
pub struct WLayerNorm {
inner: candle_nn::LayerNorm,
eps: f64,
}
impl WLayerNorm {
pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
let cfg = candle_nn::layer_norm::LayerNormConfig {
eps: 1e-6,
remove_mean: true,
affine: false,
};
let inner = candle_nn::layer_norm(size, cfg, vb)?;
Ok(Self { inner })
pub fn new(_size: usize) -> Result<Self> {
Ok(Self { eps: 1e-6 })
}
}
impl Module for WLayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.permute((0, 2, 3, 1))?
.apply(&self.inner)?
let xs = xs.permute((0, 2, 3, 1))?;
let x_dtype = xs.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = xs.dim(D::Minus1)?;
let xs = xs.to_dtype(internal_dtype)?;
let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let xs = xs.broadcast_sub(&mean_x)?;
let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
.to_dtype(x_dtype)?
.permute((0, 3, 1, 2))
}
}
@ -57,8 +64,8 @@ pub struct GlobalResponseNorm {
impl GlobalResponseNorm {
pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let gamma = vb.get((1, 1, 1, 1, dim), "gamma")?;
let beta = vb.get((1, 1, 1, 1, dim), "beta")?;
let gamma = vb.get((1, 1, 1, dim), "gamma")?;
let beta = vb.get((1, 1, 1, dim), "beta")?;
Ok(Self { gamma, beta })
}
}
@ -92,7 +99,7 @@ impl ResBlock {
..Default::default()
};
let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
let norm = WLayerNorm::new(c, vb.pp("norm"))?;
let norm = WLayerNorm::new(c)?;
let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@ -141,7 +148,7 @@ impl AttnBlock {
self_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let norm = WLayerNorm::new(c, vb.pp("norm"))?;
let norm = WLayerNorm::new(c)?;
let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
Ok(Self {