diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index 74e1836c..001b35d7 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -1,5 +1,4 @@ -#![allow(unused)] -use super::common::{AttnBlock, GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm}; +use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm}; use candle::{DType, Module, Result, Tensor, D}; use candle_nn::VarBuilder; @@ -223,7 +222,7 @@ impl WDiffNeXt { }; sub_blocks.push(sub_block) } - let (layer_norm, conv, start_layer_i) = if i > 0 { + let (layer_norm, conv) = if i > 0 { let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?; layer_i += 1; let cfg = candle_nn::Conv2dConfig { @@ -231,10 +230,9 @@ impl WDiffNeXt { ..Default::default() }; let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?; - layer_i += 1; - (Some(layer_norm), Some(conv), 2) + (Some(layer_norm), Some(conv)) } else { - (None, None, 0) + (None, None) }; let up_block = UpBlock { layer_norm, @@ -337,7 +335,7 @@ impl WDiffNeXt { level_outputs.reverse(); for (i, up_block) in self.up_blocks.iter().enumerate() { - let skip = match &self.effnet_mappers[self.down_blocks.len() + i] { + let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] { None => None, Some(m) => { let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; @@ -350,7 +348,12 @@ impl WDiffNeXt { } else { None }; - xs = block.res_block.forward(&xs, skip)?; + let skip = match (skip, effnet_c.as_ref()) { + (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?), + (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()), + (None, None) => None, + }; + xs = block.res_block.forward(&xs, skip.as_ref())?; xs = block.ts_block.forward(&xs, &r_embed)?; if let Some(attn_block) = &block.attn_block { xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs index 6589a07d..1268047a 100644 --- a/candle-transformers/src/models/wuerstchen/paella_vq.rs +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -1,10 +1,8 @@ -#![allow(unused)] -use super::common::{AttnBlock, ResBlock, TimestepBlock}; -use candle::{DType, Module, Result, Tensor, D}; +use candle::{Module, Result, Tensor}; use candle_nn::VarBuilder; #[derive(Debug)] -struct MixingResidualBlock { +pub struct MixingResidualBlock { norm1: candle_nn::LayerNorm, depthwise_conv: candle_nn::Conv2d, norm2: candle_nn::LayerNorm, diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs index 5dd03778..a9e3e793 100644 --- a/candle-transformers/src/models/wuerstchen/prior.rs +++ b/candle-transformers/src/models/wuerstchen/prior.rs @@ -1,6 +1,5 @@ -#![allow(unused)] use super::common::{AttnBlock, ResBlock, TimestepBlock}; -use candle::{DType, Module, Result, Tensor, D}; +use candle::{DType, Result, Tensor, D}; use candle_nn::VarBuilder; #[derive(Debug)]