More Wuerstchen fixes. (#882)

* More Weurstchen fixes.

* More shape fixes.

* Add more of the prior specific bits.

* Broadcast add.

* Fix the clip config.

* Add some masking options to the clip model.
This commit is contained in:
Laurent Mazare
2023-09-17 22:08:11 +01:00
committed by GitHub
parent 06cc329e71
commit c2b866172a
4 changed files with 96 additions and 41 deletions

View File

@ -75,9 +75,9 @@ impl Module for GlobalResponseNorm {
let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
let stand_div_norm =
agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
(xs.broadcast_mul(&stand_div_norm)?
.broadcast_mul(&self.gamma)
+ &self.beta)?
xs.broadcast_mul(&stand_div_norm)?
.broadcast_mul(&self.gamma)?
.broadcast_add(&self.beta)?
+ xs
}
}