mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
W fixes. (#862)
This commit is contained in:
@ -1,5 +1,4 @@
|
|||||||
#![allow(unused)]
|
use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
|
||||||
use super::common::{AttnBlock, GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
|
|
||||||
use candle::{DType, Module, Result, Tensor, D};
|
use candle::{DType, Module, Result, Tensor, D};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
@ -223,7 +222,7 @@ impl WDiffNeXt {
|
|||||||
};
|
};
|
||||||
sub_blocks.push(sub_block)
|
sub_blocks.push(sub_block)
|
||||||
}
|
}
|
||||||
let (layer_norm, conv, start_layer_i) = if i > 0 {
|
let (layer_norm, conv) = if i > 0 {
|
||||||
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
|
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
|
||||||
layer_i += 1;
|
layer_i += 1;
|
||||||
let cfg = candle_nn::Conv2dConfig {
|
let cfg = candle_nn::Conv2dConfig {
|
||||||
@ -231,10 +230,9 @@ impl WDiffNeXt {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
|
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
|
||||||
layer_i += 1;
|
(Some(layer_norm), Some(conv))
|
||||||
(Some(layer_norm), Some(conv), 2)
|
|
||||||
} else {
|
} else {
|
||||||
(None, None, 0)
|
(None, None)
|
||||||
};
|
};
|
||||||
let up_block = UpBlock {
|
let up_block = UpBlock {
|
||||||
layer_norm,
|
layer_norm,
|
||||||
@ -337,7 +335,7 @@ impl WDiffNeXt {
|
|||||||
level_outputs.reverse();
|
level_outputs.reverse();
|
||||||
|
|
||||||
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
||||||
let skip = match &self.effnet_mappers[self.down_blocks.len() + i] {
|
let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
|
||||||
None => None,
|
None => None,
|
||||||
Some(m) => {
|
Some(m) => {
|
||||||
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
|
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
|
||||||
@ -350,7 +348,12 @@ impl WDiffNeXt {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
xs = block.res_block.forward(&xs, skip)?;
|
let skip = match (skip, effnet_c.as_ref()) {
|
||||||
|
(Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
|
||||||
|
(None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
|
||||||
|
(None, None) => None,
|
||||||
|
};
|
||||||
|
xs = block.res_block.forward(&xs, skip.as_ref())?;
|
||||||
xs = block.ts_block.forward(&xs, &r_embed)?;
|
xs = block.ts_block.forward(&xs, &r_embed)?;
|
||||||
if let Some(attn_block) = &block.attn_block {
|
if let Some(attn_block) = &block.attn_block {
|
||||||
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
|
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
#![allow(unused)]
|
use candle::{Module, Result, Tensor};
|
||||||
use super::common::{AttnBlock, ResBlock, TimestepBlock};
|
|
||||||
use candle::{DType, Module, Result, Tensor, D};
|
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct MixingResidualBlock {
|
pub struct MixingResidualBlock {
|
||||||
norm1: candle_nn::LayerNorm,
|
norm1: candle_nn::LayerNorm,
|
||||||
depthwise_conv: candle_nn::Conv2d,
|
depthwise_conv: candle_nn::Conv2d,
|
||||||
norm2: candle_nn::LayerNorm,
|
norm2: candle_nn::LayerNorm,
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#![allow(unused)]
|
|
||||||
use super::common::{AttnBlock, ResBlock, TimestepBlock};
|
use super::common::{AttnBlock, ResBlock, TimestepBlock};
|
||||||
use candle::{DType, Module, Result, Tensor, D};
|
use candle::{DType, Result, Tensor, D};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
Reference in New Issue
Block a user