mirror of https://github.com/huggingface/candle.git
Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.
* Fixes.
* More fixes (including conv-transpose2d).
* More fixes.
* Again more fixes.
@@ -1,28 +1,35 @@
-use candle::{Module, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
 // https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
 #[derive(Debug)]
 pub struct WLayerNorm {
-    inner: candle_nn::LayerNorm,
+    eps: f64,
 }
 
 impl WLayerNorm {
-    pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
-        let cfg = candle_nn::layer_norm::LayerNormConfig {
-            eps: 1e-6,
-            remove_mean: true,
-            affine: false,
-        };
-        let inner = candle_nn::layer_norm(size, cfg, vb)?;
-        Ok(Self { inner })
+    pub fn new(_size: usize) -> Result<Self> {
+        Ok(Self { eps: 1e-6 })
     }
 }
 
 impl Module for WLayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.permute((0, 2, 3, 1))?
-            .apply(&self.inner)?
+        let xs = xs.permute((0, 2, 3, 1))?;
+
+        let x_dtype = xs.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+
+        let hidden_size = xs.dim(D::Minus1)?;
+        let xs = xs.to_dtype(internal_dtype)?;
+        let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        let xs = xs.broadcast_sub(&mean_x)?;
+        let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)?
             .permute((0, 3, 1, 2))
     }
 }
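For reference, the parameter-free forward above is plain layer normalization without a learned scale or bias, matching the `remove_mean: true` / `affine: false` config it replaces: subtract the per-position mean over the channel dimension, then divide by the square root of the mean squared deviation plus `eps`. A minimal standalone sketch of the same math over a single feature vector (independent of candle; the function name is hypothetical):

```rust
/// Layer-normalize one feature vector with no learned gamma/beta:
/// subtract the mean, then divide by sqrt(mean squared deviation + eps).
fn layer_norm_no_affine(xs: &[f32], eps: f32) -> Vec<f32> {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    let var = xs.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let denom = (var + eps).sqrt();
    xs.iter().map(|x| (x - mean) / denom).collect()
}

fn main() {
    // mean = 2.5, variance = 1.25, so the output is roughly
    // [-1.342, -0.447, 0.447, 1.342].
    println!("{:?}", layer_norm_no_affine(&[1.0, 2.0, 3.0, 4.0], 1e-6));
}
```

The upcast to `DType::F32` in the tensor version mirrors this scalar math: accumulating sums and squared deviations directly in f16/bf16 loses precision, so the reductions run in f32 and the result is cast back to the input dtype at the end.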
@@ -57,8 +64,8 @@ pub struct GlobalResponseNorm {
 
 impl GlobalResponseNorm {
     pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
-        let gamma = vb.get((1, 1, 1, 1, dim), "gamma")?;
-        let beta = vb.get((1, 1, 1, 1, dim), "beta")?;
+        let gamma = vb.get((1, 1, 1, dim), "gamma")?;
+        let beta = vb.get((1, 1, 1, dim), "beta")?;
         Ok(Self { gamma, beta })
     }
 }
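The `gamma`/`beta` shape change from `(1, 1, 1, 1, dim)` to `(1, 1, 1, dim)` matches the 4-D NHWC activations these parameters broadcast against; with the extra singleton dimension the broadcast would not stay 4-D. A quick sketch of that broadcast, assuming the in-tree `candle` crate (shapes chosen purely for illustration):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // NHWC activations: (batch, height, width, channels).
    let xs = Tensor::zeros((2, 8, 8, 4), DType::F32, &dev)?;
    // Per-channel parameter with singleton leading dims, as `vb.get` now loads it.
    let gamma = Tensor::ones((1, 1, 1, 4), DType::F32, &dev)?;
    // Broadcasting aligns on the trailing channel dimension.
    let scaled = xs.broadcast_mul(&gamma)?;
    println!("{:?}", scaled.shape()); // [2, 8, 8, 4]
    Ok(())
}
```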
@@ -92,7 +99,7 @@ impl ResBlock {
             ..Default::default()
         };
         let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
-        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let norm = WLayerNorm::new(c)?;
         let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
         let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
         let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@@ -141,7 +148,7 @@ impl AttnBlock {
         self_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
-        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+        let norm = WLayerNorm::new(c)?;
         let attention = Attention::new(vb.pp("attention"), c, None, nhead, c / nhead, None, false)?;
         let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
         Ok(Self {
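The last two hunks are the call-site updates: since `WLayerNorm` no longer holds parameters, its constructor drops the `VarBuilder` argument. A hedged sketch of the new signature in use, assuming `WLayerNorm` from the Wuerstchen model code is in scope:

```rust
use candle::{Device, Module, Result, Tensor};

// `WLayerNorm` is assumed to be imported from the Wuerstchen common module.
fn demo(c: usize) -> Result<Tensor> {
    // NCHW input, as the surrounding blocks produce.
    let xs = Tensor::randn(0f32, 1f32, (1, c, 16, 16), &Device::Cpu)?;
    let norm = WLayerNorm::new(c)?; // was: WLayerNorm::new(c, vb.pp("norm"))?
    // forward() permutes to NHWC, normalizes the channel dim, and permutes back.
    xs.apply(&norm)
}
```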