Add another specific layer-norm structure.

This commit is contained in:
laurent
2023-09-19 09:06:10 +01:00
parent b936e32e11
commit 49a4fa44bb
3 changed files with 36 additions and 8 deletions

View File

@ -34,6 +34,34 @@ impl Module for WLayerNorm {
}
}
#[derive(Debug)]
pub struct LayerNormNoWeights {
eps: f64,
}
impl LayerNormNoWeights {
pub fn new(_size: usize) -> Result<Self> {
Ok(Self { eps: 1e-6 })
}
}
impl Module for LayerNormNoWeights {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let x_dtype = xs.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = xs.dim(D::Minus1)?;
let xs = xs.to_dtype(internal_dtype)?;
let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let xs = xs.broadcast_sub(&mean_x)?;
let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
.to_dtype(x_dtype)
}
}
#[derive(Debug)]
pub struct TimestepBlock {
mapper: candle_nn::Linear,

View File

@ -1,4 +1,4 @@
use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
@ -75,7 +75,7 @@ struct UpBlock {
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
seq_norm: WLayerNorm,
seq_norm: LayerNormNoWeights,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
down_blocks: Vec<DownBlock>,
@ -133,7 +133,7 @@ impl WDiffNeXt {
};
effnet_mappers.push(c)
}
let seq_norm = WLayerNorm::new(c_cond)?;
let seq_norm = LayerNormNoWeights::new(c_cond)?;
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
let embedding_conv = candle_nn::conv2d(
c_in * patch_size * patch_size,

View File

@ -1,12 +1,12 @@
use super::common::WLayerNorm;
use super::common::LayerNormNoWeights;
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
#[derive(Debug)]
pub struct MixingResidualBlock {
norm1: WLayerNorm,
norm1: LayerNormNoWeights,
depthwise_conv: candle_nn::Conv2d,
norm2: WLayerNorm,
norm2: LayerNormNoWeights,
channelwise_lin1: candle_nn::Linear,
channelwise_lin2: candle_nn::Linear,
gammas: Vec<f32>,
@ -14,8 +14,8 @@ pub struct MixingResidualBlock {
impl MixingResidualBlock {
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
let norm1 = WLayerNorm::new(inp)?;
let norm2 = WLayerNorm::new(inp)?;
let norm1 = LayerNormNoWeights::new(inp)?;
let norm2 = LayerNormNoWeights::new(inp)?;
let cfg = candle_nn::Conv2dConfig {
groups: inp,
..Default::default()