Line-up the wuerstchen model with the python implementation. (#901)

* Line-up the wuerstchen model with the python implementation.

* Missing cos.

* Fix the picture denormalization.
This commit is contained in:
Laurent Mazare
2023-09-19 21:59:44 +01:00
committed by GitHub
parent 7ad82b87e4
commit 67a486d18d
4 changed files with 10 additions and 7 deletions

View File

@ -100,7 +100,7 @@ impl GlobalResponseNorm {
impl Module for GlobalResponseNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let stand_div_norm =
agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
xs.broadcast_mul(&stand_div_norm)?
@ -152,7 +152,7 @@ impl ResBlock {
.permute((0, 2, 3, 1))?;
let xs = xs
.apply(&self.channelwise_lin1)?
.gelu()?
.gelu_erf()?
.apply(&self.channelwise_grn)?
.apply(&self.channelwise_lin2)?
.permute((0, 3, 1, 2))?;

View File

@ -52,8 +52,10 @@ impl DDPMWScheduler {
} else {
t
};
let alpha_cumprod =
((t + s) / (1. + s) * std::f64::consts::PI * 0.5).powi(2) / self.init_alpha_cumprod;
let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
.cos()
.powi(2)
/ self.init_alpha_cumprod;
alpha_cumprod.clamp(0.0001, 0.9999)
}