mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add another specific layer-norm structure.
This commit is contained in:
@ -34,6 +34,34 @@ impl Module for WLayerNorm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Layer normalization without learnable parameters.
///
/// Unlike a regular layer-norm there is no weight (scale) or bias (shift)
/// tensor: the only state is the epsilon added to the variance before taking
/// the square root.
#[derive(Debug)]
pub struct LayerNormNoWeights {
    /// Small constant added to the variance for numerical stability.
    eps: f64,
}
|
||||||
|
|
||||||
|
impl LayerNormNoWeights {
|
||||||
|
pub fn new(_size: usize) -> Result<Self> {
|
||||||
|
Ok(Self { eps: 1e-6 })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for LayerNormNoWeights {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let x_dtype = xs.dtype();
|
||||||
|
let internal_dtype = match x_dtype {
|
||||||
|
DType::F16 | DType::BF16 => DType::F32,
|
||||||
|
d => d,
|
||||||
|
};
|
||||||
|
let hidden_size = xs.dim(D::Minus1)?;
|
||||||
|
let xs = xs.to_dtype(internal_dtype)?;
|
||||||
|
let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||||
|
let xs = xs.broadcast_sub(&mean_x)?;
|
||||||
|
let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||||
|
xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
|
||||||
|
.to_dtype(x_dtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TimestepBlock {
|
pub struct TimestepBlock {
|
||||||
mapper: candle_nn::Linear,
|
mapper: candle_nn::Linear,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
|
use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
|
||||||
use candle::{DType, Module, Result, Tensor, D};
|
use candle::{DType, Module, Result, Tensor, D};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ struct UpBlock {
|
|||||||
pub struct WDiffNeXt {
|
pub struct WDiffNeXt {
|
||||||
clip_mapper: candle_nn::Linear,
|
clip_mapper: candle_nn::Linear,
|
||||||
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
|
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
|
||||||
seq_norm: WLayerNorm,
|
seq_norm: LayerNormNoWeights,
|
||||||
embedding_conv: candle_nn::Conv2d,
|
embedding_conv: candle_nn::Conv2d,
|
||||||
embedding_ln: WLayerNorm,
|
embedding_ln: WLayerNorm,
|
||||||
down_blocks: Vec<DownBlock>,
|
down_blocks: Vec<DownBlock>,
|
||||||
@ -133,7 +133,7 @@ impl WDiffNeXt {
|
|||||||
};
|
};
|
||||||
effnet_mappers.push(c)
|
effnet_mappers.push(c)
|
||||||
}
|
}
|
||||||
let seq_norm = WLayerNorm::new(c_cond)?;
|
let seq_norm = LayerNormNoWeights::new(c_cond)?;
|
||||||
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
|
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
|
||||||
let embedding_conv = candle_nn::conv2d(
|
let embedding_conv = candle_nn::conv2d(
|
||||||
c_in * patch_size * patch_size,
|
c_in * patch_size * patch_size,
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
use super::common::WLayerNorm;
|
use super::common::LayerNormNoWeights;
|
||||||
use candle::{Module, Result, Tensor};
|
use candle::{Module, Result, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MixingResidualBlock {
|
pub struct MixingResidualBlock {
|
||||||
norm1: WLayerNorm,
|
norm1: LayerNormNoWeights,
|
||||||
depthwise_conv: candle_nn::Conv2d,
|
depthwise_conv: candle_nn::Conv2d,
|
||||||
norm2: WLayerNorm,
|
norm2: LayerNormNoWeights,
|
||||||
channelwise_lin1: candle_nn::Linear,
|
channelwise_lin1: candle_nn::Linear,
|
||||||
channelwise_lin2: candle_nn::Linear,
|
channelwise_lin2: candle_nn::Linear,
|
||||||
gammas: Vec<f32>,
|
gammas: Vec<f32>,
|
||||||
@ -14,8 +14,8 @@ pub struct MixingResidualBlock {
|
|||||||
|
|
||||||
impl MixingResidualBlock {
|
impl MixingResidualBlock {
|
||||||
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
|
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
let norm1 = WLayerNorm::new(inp)?;
|
let norm1 = LayerNormNoWeights::new(inp)?;
|
||||||
let norm2 = WLayerNorm::new(inp)?;
|
let norm2 = LayerNormNoWeights::new(inp)?;
|
||||||
let cfg = candle_nn::Conv2dConfig {
|
let cfg = candle_nn::Conv2dConfig {
|
||||||
groups: inp,
|
groups: inp,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
Reference in New Issue
Block a user