Remove the parameters for the Wuerstchen layer-norm. (#879)

* Remove the parameters for the Wuerstchen layer-norm.

* Fixes.

* More fixes (including conv-transpose2d.

* More fixes.

* Again more fixes.
This commit is contained in:
Laurent Mazare
2023-09-17 15:59:27 +01:00
committed by GitHub
parent 5f83c13f17
commit 06cc329e71
5 changed files with 45 additions and 45 deletions

View File

@ -33,7 +33,7 @@ impl WPrior {
let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
let out_ln = super::common::WLayerNorm::new(c, vb.pp("out.0"))?;
let out_ln = super::common::WLayerNorm::new(c)?;
let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
let mut blocks = Vec::with_capacity(depth);
for index in 0..depth {