From 49a4fa44bbfbc83127cc794f2b54108545f789b0 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 19 Sep 2023 09:06:10 +0100
Subject: [PATCH] Add another specific layer-norm structure.

---
 .../src/models/wuerstchen/common.rs    | 28 +++++++++++++++++++
 .../src/models/wuerstchen/diffnext.rs  |  6 ++--
 .../src/models/wuerstchen/paella_vq.rs | 10 +++---
 3 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 1eb0c2e7..3cac2a59 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -34,6 +34,34 @@ impl Module for WLayerNorm {
     }
 }
 
+#[derive(Debug)]
+pub struct LayerNormNoWeights {
+    eps: f64,
+}
+
+impl LayerNormNoWeights {
+    pub fn new(_size: usize) -> Result<Self> {
+        Ok(Self { eps: 1e-6 })
+    }
+}
+
+impl Module for LayerNormNoWeights {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let x_dtype = xs.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let hidden_size = xs.dim(D::Minus1)?;
+        let xs = xs.to_dtype(internal_dtype)?;
+        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        let xs = xs.broadcast_sub(&mean_x)?;
+        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)
+    }
+}
+
 #[derive(Debug)]
 pub struct TimestepBlock {
     mapper: candle_nn::Linear,
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 664251ed..afa83a16 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -1,4 +1,4 @@
-use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
+use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
 use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
@@ -75,7 +75,7 @@ struct UpBlock {
 pub struct WDiffNeXt {
     clip_mapper: candle_nn::Linear,
     effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
-    seq_norm: WLayerNorm,
+    seq_norm: LayerNormNoWeights,
     embedding_conv: candle_nn::Conv2d,
     embedding_ln: WLayerNorm,
     down_blocks: Vec<DownBlock>,
@@ -133,7 +133,7 @@ impl WDiffNeXt {
             };
             effnet_mappers.push(c)
         }
-        let seq_norm = WLayerNorm::new(c_cond)?;
+        let seq_norm = LayerNormNoWeights::new(c_cond)?;
         let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
         let embedding_conv = candle_nn::conv2d(
            c_in * patch_size * patch_size,
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index faf2d2b4..8cf33505 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -1,12 +1,12 @@
-use super::common::WLayerNorm;
+use super::common::LayerNormNoWeights;
 use candle::{Module, Result, Tensor};
 use candle_nn::VarBuilder;
 
 #[derive(Debug)]
 pub struct MixingResidualBlock {
-    norm1: WLayerNorm,
+    norm1: LayerNormNoWeights,
     depthwise_conv: candle_nn::Conv2d,
-    norm2: WLayerNorm,
+    norm2: LayerNormNoWeights,
     channelwise_lin1: candle_nn::Linear,
     channelwise_lin2: candle_nn::Linear,
     gammas: Vec<f32>,
@@ -14,8 +14,8 @@ }
 
 impl MixingResidualBlock {
     pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
-        let norm1 = WLayerNorm::new(inp)?;
-        let norm2 = WLayerNorm::new(inp)?;
+        let norm1 = LayerNormNoWeights::new(inp)?;
+        let norm2 = LayerNormNoWeights::new(inp)?;
         let cfg = candle_nn::Conv2dConfig {
             groups: inp,
             ..Default::default()
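
The LayerNormNoWeights added above is a plain layer normalization over the last
dimension: it subtracts the per-vector mean and divides by the square root of
the mean square plus eps, carrying no learnable scale or bias, which is why
new() can ignore its size argument and only store eps; half-precision inputs
are normalized in f32 and cast back. A minimal usage sketch, assuming the
struct stays publicly reachable as
candle_transformers::models::wuerstchen::common::LayerNormNoWeights and a
candle version providing Tensor::randn and mean_keepdim:

use candle::{Device, Module, Result, Tensor, D};
use candle_transformers::models::wuerstchen::common::LayerNormNoWeights;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Toy activations: batch 2, sequence 3, hidden size 8.
    let xs = Tensor::randn(0f32, 1f32, (2, 3, 8), &dev)?;
    // No weights are loaded; the layer only stores eps = 1e-6,
    // so the size argument is effectively unused.
    let ln = LayerNormNoWeights::new(8)?;
    let ys = ln.forward(&xs)?;
    // Each vector along the last dimension is now centered:
    // mean ~0 and mean square ~1 (up to eps).
    let mean = ys.mean_keepdim(D::Minus1)?;
    let mean_sq = ys.sqr()?.mean_keepdim(D::Minus1)?;
    println!("mean:\n{mean}\nmean square:\n{mean_sq}");
    Ok(())
}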