mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Line-up the wuerstchen model with the python implementation. (#901)
* Line-up the wuerstchen model with the python implementation.
* Missing cos.
* Fix the picture denormalization.
This commit is contained in:
@ -12,6 +12,7 @@ use candle_nn::Module;
|
||||
pub enum Activation {
|
||||
QuickGelu,
|
||||
Gelu,
|
||||
GeluErf,
|
||||
}
|
||||
|
||||
impl Module for Activation {
|
||||
@ -19,6 +20,7 @@ impl Module for Activation {
|
||||
match self {
|
||||
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
|
||||
Activation::Gelu => xs.gelu(),
|
||||
Activation::GeluErf => xs.gelu_erf(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -111,7 +113,7 @@ impl Config {
|
||||
num_hidden_layers: 24,
|
||||
num_attention_heads: 16,
|
||||
projection_dim: 1024,
|
||||
activation: Activation::Gelu,
|
||||
activation: Activation::GeluErf,
|
||||
}
|
||||
}
|
||||
|
||||
@ -126,7 +128,7 @@ impl Config {
|
||||
num_hidden_layers: 32,
|
||||
num_attention_heads: 20,
|
||||
projection_dim: 512,
|
||||
activation: Activation::Gelu,
|
||||
activation: Activation::GeluErf,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -100,7 +100,7 @@ impl GlobalResponseNorm {
|
||||
|
||||
impl Module for GlobalResponseNorm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
|
||||
let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let stand_div_norm =
|
||||
agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
|
||||
xs.broadcast_mul(&stand_div_norm)?
|
||||
@ -152,7 +152,7 @@ impl ResBlock {
|
||||
.permute((0, 2, 3, 1))?;
|
||||
let xs = xs
|
||||
.apply(&self.channelwise_lin1)?
|
||||
.gelu()?
|
||||
.gelu_erf()?
|
||||
.apply(&self.channelwise_grn)?
|
||||
.apply(&self.channelwise_lin2)?
|
||||
.permute((0, 3, 1, 2))?;
|
||||
|
@ -52,8 +52,10 @@ impl DDPMWScheduler {
|
||||
} else {
|
||||
t
|
||||
};
|
||||
let alpha_cumprod =
|
||||
((t + s) / (1. + s) * std::f64::consts::PI * 0.5).powi(2) / self.init_alpha_cumprod;
|
||||
let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
|
||||
.cos()
|
||||
.powi(2)
|
||||
/ self.init_alpha_cumprod;
|
||||
alpha_cumprod.clamp(0.0001, 0.9999)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user