Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Line-up the wuerstchen model with the python implementation. (#901)

* Line-up the wuerstchen model with the python implementation.
* Missing cos.
* Fix the picture denormalization.
@@ -373,7 +373,6 @@ fn run(args: Args) -> Result<()> {
         );
         let image = vqgan.decode(&(&latents * 0.3764)?)?;
         // TODO: Add the clamping between 0 and 1.
-        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
         let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
         let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
         candle_examples::save_image(&image, image_filename)?
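This removal is the "Fix the picture denormalization" part of the commit message: the dropped line remapped the decoded image from [-1, 1] to [0, 1] before the scaling to 255, but the TODO above it suggests the VQGAN output already lives in [0, 1], so the extra remap brightened and washed out the result. A minimal plain-float sketch of the difference; the 0.6 sample value is made up for illustration, not taken from the model:

fn main() {
    let decoded = 0.6f32; // assume the VQGAN already emits values in [0, 1]
    let old = (decoded / 2.0 + 0.5) * 255.0; // extra remap: 204.0, too bright
    let new = decoded * 255.0; // direct scaling: 153.0
    println!("old: {old}, new: {new}");
}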
@@ -12,6 +12,7 @@ use candle_nn::Module;
 pub enum Activation {
     QuickGelu,
     Gelu,
+    GeluErf,
 }
 
 impl Module for Activation {
@@ -19,6 +20,7 @@ impl Module for Activation {
         match self {
             Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
             Activation::Gelu => xs.gelu(),
+            Activation::GeluErf => xs.gelu_erf(),
         }
     }
 }
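The new GeluErf variant dispatches to candle's gelu_erf, the exact erf-based GELU, while Gelu maps to the faster tanh approximation. A self-contained sketch of the two definitions on plain f64 values; since std has no erf, the Abramowitz & Stegun 7.1.26 polynomial approximation stands in for it here:

fn gelu_tanh(x: f64) -> f64 {
    // tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    0.5 * x * (1.0 + ((2.0 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

fn erf(x: f64) -> f64 {
    // Abramowitz & Stegun 7.1.26 polynomial approximation, |error| < 1.5e-7.
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;
    sign * (1.0 - poly * (-x * x).exp())
}

fn gelu_erf(x: f64) -> f64 {
    // exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    0.5 * x * (1.0 + erf(x / 2f64.sqrt()))
}

fn main() {
    for x in [-1.0f64, 0.5, 2.0] {
        println!("x={x}: tanh-gelu={:.6} erf-gelu={:.6}", gelu_tanh(x), gelu_erf(x));
    }
}

The same swap shows up in the remaining hunks: the two CLIP configs move from Activation::Gelu to Activation::GeluErf, and ResBlock replaces gelu() with gelu_erf(), aligning all three with the Python implementation.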
@@ -111,7 +113,7 @@ impl Config {
             num_hidden_layers: 24,
             num_attention_heads: 16,
             projection_dim: 1024,
-            activation: Activation::Gelu,
+            activation: Activation::GeluErf,
         }
     }
 
@@ -126,7 +128,7 @@ impl Config {
             num_hidden_layers: 32,
             num_attention_heads: 20,
             projection_dim: 512,
-            activation: Activation::Gelu,
+            activation: Activation::GeluErf,
         }
     }
 }
@@ -100,7 +100,7 @@ impl GlobalResponseNorm {
 
 impl Module for GlobalResponseNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
+        let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
         let stand_div_norm =
             agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
         xs.broadcast_mul(&stand_div_norm)?
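The fix in this hunk: global response norm aggregates the spatial L2 norm per channel, and the old code stopped at the sum of squares, i.e. the squared norm. A toy numeric sketch of what the added sqrt changes, using a plain array rather than candle tensors:

fn main() {
    let channel = [3.0f64, 4.0]; // one channel with two spatial positions
    let sum_sq: f64 = channel.iter().map(|x| x * x).sum();
    println!("before: {sum_sq}"); // 25.0: squared norm, what the old code fed downstream
    println!("after: {}", sum_sq.sqrt()); // 5.0: the actual L2 norm, matching Python
}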
@@ -152,7 +152,7 @@ impl ResBlock {
             .permute((0, 2, 3, 1))?;
         let xs = xs
             .apply(&self.channelwise_lin1)?
-            .gelu()?
+            .gelu_erf()?
             .apply(&self.channelwise_grn)?
             .apply(&self.channelwise_lin2)?
             .permute((0, 3, 1, 2))?;
|
@ -52,8 +52,10 @@ impl DDPMWScheduler {
|
|||||||
} else {
|
} else {
|
||||||
t
|
t
|
||||||
};
|
};
|
||||||
let alpha_cumprod =
|
let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
|
||||||
((t + s) / (1. + s) * std::f64::consts::PI * 0.5).powi(2) / self.init_alpha_cumprod;
|
.cos()
|
||||||
|
.powi(2)
|
||||||
|
/ self.init_alpha_cumprod;
|
||||||
alpha_cumprod.clamp(0.0001, 0.9999)
|
alpha_cumprod.clamp(0.0001, 0.9999)
|
||||||
}
|
}
|
||||||
|
|
||||||
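With the restored cos (the "Missing cos" of the commit message) this is the usual squared-cosine noise schedule, alpha_cumprod(t) = cos((t + s) / (1 + s) * pi / 2)^2 / alpha_cumprod(0); the old code squared the angle itself instead of its cosine. A standalone sketch, assuming the conventional s = 0.008 offset and deriving init_alpha_cumprod from it (both assumptions, since the scheduler's fields are not shown in this hunk):

fn alpha_cumprod(t: f64, s: f64, init_alpha_cumprod: f64) -> f64 {
    // Squared-cosine schedule with the restored cos, clamped like the scheduler.
    let ac = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
        .cos()
        .powi(2)
        / init_alpha_cumprod;
    ac.clamp(0.0001, 0.9999)
}

fn main() {
    let s = 0.008; // assumed offset, standard for cosine schedules
    let init = (s / (1. + s) * std::f64::consts::PI * 0.5).cos().powi(2);
    for t in [0.0, 0.25, 0.5, 0.75, 1.0] {
        // decays smoothly from ~1 at t = 0 toward 0 at t = 1
        println!("t={t}: alpha_cumprod={:.4}", alpha_cumprod(t, s, init));
    }
}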