mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Only use classifier free guidance for the prior. (#896)
* Only use classifier free guidance for the prior. * Add another specific layer-norm structure. * Tweaks. * Fix the latent shape. * Print the prior shape. * More shape fixes. * Remove some debugging continue.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
|
||||
use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
|
||||
use candle::{DType, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
@ -75,7 +75,7 @@ struct UpBlock {
|
||||
pub struct WDiffNeXt {
|
||||
clip_mapper: candle_nn::Linear,
|
||||
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
|
||||
seq_norm: WLayerNorm,
|
||||
seq_norm: LayerNormNoWeights,
|
||||
embedding_conv: candle_nn::Conv2d,
|
||||
embedding_ln: WLayerNorm,
|
||||
down_blocks: Vec<DownBlock>,
|
||||
@ -133,7 +133,7 @@ impl WDiffNeXt {
|
||||
};
|
||||
effnet_mappers.push(c)
|
||||
}
|
||||
let seq_norm = WLayerNorm::new(c_cond)?;
|
||||
let seq_norm = LayerNormNoWeights::new(c_cond)?;
|
||||
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
|
||||
let embedding_conv = candle_nn::conv2d(
|
||||
c_in * patch_size * patch_size,
|
||||
@ -335,6 +335,7 @@ impl WDiffNeXt {
|
||||
level_outputs.push(xs.clone())
|
||||
}
|
||||
level_outputs.reverse();
|
||||
let mut xs = level_outputs[0].clone();
|
||||
|
||||
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
||||
let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
|
||||
|
Reference in New Issue
Block a user