mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add another specific layer-norm structure.
This commit is contained in:
@ -34,6 +34,34 @@ impl Module for WLayerNorm {
|
||||
}
|
||||
}
|
||||
|
||||
/// Layer normalization without learnable scale/shift parameters:
/// only the stabilizing epsilon is stored.
#[derive(Debug)]
pub struct LayerNormNoWeights {
    // Added to the variance before the square root for numerical stability.
    eps: f64,
}
|
||||
|
||||
impl LayerNormNoWeights {
|
||||
pub fn new(_size: usize) -> Result<Self> {
|
||||
Ok(Self { eps: 1e-6 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNormNoWeights {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = xs.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let hidden_size = xs.dim(D::Minus1)?;
|
||||
let xs = xs.to_dtype(internal_dtype)?;
|
||||
let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let xs = xs.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
|
||||
.to_dtype(x_dtype)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TimestepBlock {
|
||||
mapper: candle_nn::Linear,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
|
||||
use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
|
||||
use candle::{DType, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
@ -75,7 +75,7 @@ struct UpBlock {
|
||||
pub struct WDiffNeXt {
|
||||
clip_mapper: candle_nn::Linear,
|
||||
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
|
||||
seq_norm: WLayerNorm,
|
||||
seq_norm: LayerNormNoWeights,
|
||||
embedding_conv: candle_nn::Conv2d,
|
||||
embedding_ln: WLayerNorm,
|
||||
down_blocks: Vec<DownBlock>,
|
||||
@ -133,7 +133,7 @@ impl WDiffNeXt {
|
||||
};
|
||||
effnet_mappers.push(c)
|
||||
}
|
||||
let seq_norm = WLayerNorm::new(c_cond)?;
|
||||
let seq_norm = LayerNormNoWeights::new(c_cond)?;
|
||||
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
|
||||
let embedding_conv = candle_nn::conv2d(
|
||||
c_in * patch_size * patch_size,
|
||||
|
@ -1,12 +1,12 @@
|
||||
use super::common::WLayerNorm;
|
||||
use super::common::LayerNormNoWeights;
|
||||
use candle::{Module, Result, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MixingResidualBlock {
|
||||
norm1: WLayerNorm,
|
||||
norm1: LayerNormNoWeights,
|
||||
depthwise_conv: candle_nn::Conv2d,
|
||||
norm2: WLayerNorm,
|
||||
norm2: LayerNormNoWeights,
|
||||
channelwise_lin1: candle_nn::Linear,
|
||||
channelwise_lin2: candle_nn::Linear,
|
||||
gammas: Vec<f32>,
|
||||
@ -14,8 +14,8 @@ pub struct MixingResidualBlock {
|
||||
|
||||
impl MixingResidualBlock {
|
||||
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let norm1 = WLayerNorm::new(inp)?;
|
||||
let norm2 = WLayerNorm::new(inp)?;
|
||||
let norm1 = LayerNormNoWeights::new(inp)?;
|
||||
let norm2 = LayerNormNoWeights::new(inp)?;
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
groups: inp,
|
||||
..Default::default()
|
||||
|
Reference in New Issue
Block a user