diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index 8064f87f..bce68114 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -373,7 +373,6 @@ fn run(args: Args) -> Result<()> {
         );
         let image = vqgan.decode(&(&latents * 0.3764)?)?;
         // TODO: Add the clamping between 0 and 1.
-        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
         let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
         let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
         candle_examples::save_image(&image, image_filename)?
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 7f86cf31..e7a20270 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -12,6 +12,7 @@ use candle_nn::Module;
 pub enum Activation {
     QuickGelu,
     Gelu,
+    GeluErf,
 }
 
 impl Module for Activation {
@@ -19,6 +20,7 @@ impl Module for Activation {
         match self {
             Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
            Activation::Gelu => xs.gelu(),
+            Activation::GeluErf => xs.gelu_erf(),
         }
     }
 }
@@ -111,7 +113,7 @@ impl Config {
             num_hidden_layers: 24,
             num_attention_heads: 16,
             projection_dim: 1024,
-            activation: Activation::Gelu,
+            activation: Activation::GeluErf,
         }
     }
 
@@ -126,7 +128,7 @@
             num_hidden_layers: 32,
             num_attention_heads: 20,
             projection_dim: 512,
-            activation: Activation::Gelu,
+            activation: Activation::GeluErf,
         }
     }
 }
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 3cac2a59..8416a1f1 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -100,7 +100,7 @@ impl GlobalResponseNorm {
 
 impl Module for GlobalResponseNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
+        let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
         let stand_div_norm =
             agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
         xs.broadcast_mul(&stand_div_norm)?
@@ -152,7 +152,7 @@ impl ResBlock {
             .permute((0, 2, 3, 1))?;
         let xs = xs
             .apply(&self.channelwise_lin1)?
-            .gelu()?
+            .gelu_erf()?
             .apply(&self.channelwise_grn)?
             .apply(&self.channelwise_lin2)?
             .permute((0, 3, 1, 2))?;
diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs
index 80640072..9e69b868 100644
--- a/candle-transformers/src/models/wuerstchen/ddpm.rs
+++ b/candle-transformers/src/models/wuerstchen/ddpm.rs
@@ -52,8 +52,10 @@ impl DDPMWScheduler {
         } else {
             t
         };
-        let alpha_cumprod =
-            ((t + s) / (1. + s) * std::f64::consts::PI * 0.5).powi(2) / self.init_alpha_cumprod;
+        let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
+            .cos()
+            .powi(2)
+            / self.init_alpha_cumprod;
         alpha_cumprod.clamp(0.0001, 0.9999)
     }
 
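Note on the ddpm.rs hunk (illustration only, not part of the patch): the change restores the missing .cos() so that alpha_cumprod follows a squared-cosine noise schedule. Below is a minimal, self-contained Rust sketch of the corrected computation. The constant S and the locally recomputed init_alpha_cumprod are hypothetical stand-ins for the scheduler fields self.s and self.init_alpha_cumprod; S = 0.008 is assumed as a typical shift value and is not taken from this diff.

// Standalone sketch of the corrected alpha_cumprod computation (assumptions noted above).
fn alpha_cumprod_sketch(t: f64) -> f64 {
    // Hypothetical stand-in for the scheduler's `s` field.
    const S: f64 = 0.008;
    // Stand-in for `self.init_alpha_cumprod`: the unnormalized value at t = 0.
    let init_alpha_cumprod = (S / (1. + S) * std::f64::consts::PI * 0.5).cos().powi(2);
    // With the `.cos()` in place this is a squared-cosine schedule; without it the
    // expression grows quadratically in t, which is the bug the hunk fixes.
    let alpha_cumprod = ((t + S) / (1. + S) * std::f64::consts::PI * 0.5)
        .cos()
        .powi(2)
        / init_alpha_cumprod;
    alpha_cumprod.clamp(0.0001, 0.9999)
}

fn main() {
    // The value should decay from ~1 at t = 0 towards ~0 at t = 1 before clamping.
    for t in [0.0, 0.25, 0.5, 0.75, 1.0] {
        println!("t = {t:.2} -> alpha_cumprod = {:.4}", alpha_cumprod_sketch(t));
    }
}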