This commit is contained in:
Laurent Mazare
2023-09-15 16:11:11 +02:00
committed by GitHub
parent 30be5b6660
commit c2007ac88f
3 changed files with 14 additions and 14 deletions

View File

@ -1,5 +1,4 @@
#![allow(unused)] use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
use super::common::{AttnBlock, GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D}; use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
@ -223,7 +222,7 @@ impl WDiffNeXt {
}; };
sub_blocks.push(sub_block) sub_blocks.push(sub_block)
} }
let (layer_norm, conv, start_layer_i) = if i > 0 { let (layer_norm, conv) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?; let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
layer_i += 1; layer_i += 1;
let cfg = candle_nn::Conv2dConfig { let cfg = candle_nn::Conv2dConfig {
@ -231,10 +230,9 @@ impl WDiffNeXt {
..Default::default() ..Default::default()
}; };
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?; let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
layer_i += 1; (Some(layer_norm), Some(conv))
(Some(layer_norm), Some(conv), 2)
} else { } else {
(None, None, 0) (None, None)
}; };
let up_block = UpBlock { let up_block = UpBlock {
layer_norm, layer_norm,
@ -337,7 +335,7 @@ impl WDiffNeXt {
level_outputs.reverse(); level_outputs.reverse();
for (i, up_block) in self.up_blocks.iter().enumerate() { for (i, up_block) in self.up_blocks.iter().enumerate() {
let skip = match &self.effnet_mappers[self.down_blocks.len() + i] { let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
None => None, None => None,
Some(m) => { Some(m) => {
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
@ -350,7 +348,12 @@ impl WDiffNeXt {
} else { } else {
None None
}; };
xs = block.res_block.forward(&xs, skip)?; let skip = match (skip, effnet_c.as_ref()) {
(Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
(None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
(None, None) => None,
};
xs = block.res_block.forward(&xs, skip.as_ref())?;
xs = block.ts_block.forward(&xs, &r_embed)?; xs = block.ts_block.forward(&xs, &r_embed)?;
if let Some(attn_block) = &block.attn_block { if let Some(attn_block) = &block.attn_block {
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;

View File

@ -1,10 +1,8 @@
#![allow(unused)] use candle::{Module, Result, Tensor};
use super::common::{AttnBlock, ResBlock, TimestepBlock};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
#[derive(Debug)] #[derive(Debug)]
struct MixingResidualBlock { pub struct MixingResidualBlock {
norm1: candle_nn::LayerNorm, norm1: candle_nn::LayerNorm,
depthwise_conv: candle_nn::Conv2d, depthwise_conv: candle_nn::Conv2d,
norm2: candle_nn::LayerNorm, norm2: candle_nn::LayerNorm,

View File

@ -1,6 +1,5 @@
#![allow(unused)]
use super::common::{AttnBlock, ResBlock, TimestepBlock}; use super::common::{AttnBlock, ResBlock, TimestepBlock};
use candle::{DType, Module, Result, Tensor, D}; use candle::{DType, Result, Tensor, D};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
#[derive(Debug)] #[derive(Debug)]