diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 0a64a5a6..52effdcf 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -110,7 +110,7 @@ impl ToUsize2 for (usize, usize) {
 }
 
 // A simple trait defining a module with forward method using a single argument.
-pub trait Module: std::fmt::Debug {
+pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
 }
 
@@ -119,3 +119,9 @@ impl Module for quantized::QMatMul {
         self.forward(xs)
     }
 }
+
+impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self(xs)
+    }
+}
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index adf1451c..c4055792 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -44,6 +44,10 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
     (xs.neg()?.exp()? + 1.0)?.recip()
 }
 
+pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
+    // leaky_relu(x) = max(x, negative_slope * x) for 0 <= negative_slope <= 1.
+    xs.maximum(&(xs * negative_slope)?)
+}
+
 pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
     // This implementation is inefficient as it stores the full mask for the backward pass.
     // Instead we could just store the seed and have a specialized kernel that would both
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 5e49437c..7289a54d 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -161,8 +161,57 @@ impl WDiffNeXt {
             down_blocks.push(down_block)
         }
 
-        // TODO: populate.
-        let up_blocks = Vec::with_capacity(C_HIDDEN.len());
+        let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
+        for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
+            let vb = vb.pp("up_blocks").pp(i);
+            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+            let mut layer_i = 0;
+            for j in 0..BLOCKS[i] {
+                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+                let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
+                    c_hidden + c_skip
+                } else {
+                    c_skip
+                };
+                let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
+                layer_i += 1;
+                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+                layer_i += 1;
+                let attn_block = if j == 0 {
+                    None
+                } else {
+                    let attn_block =
+                        AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                    layer_i += 1;
+                    Some(attn_block)
+                };
+                let sub_block = SubBlock {
+                    res_block,
+                    ts_block,
+                    attn_block,
+                };
+                sub_blocks.push(sub_block)
+            }
+            let (layer_norm, conv, start_layer_i) = if i > 0 {
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
+                layer_i += 1;
+                let cfg = candle_nn::Conv2dConfig {
+                    stride: 2,
+                    ..Default::default()
+                };
+                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
+                layer_i += 1;
+                (Some(layer_norm), Some(conv), 2)
+            } else {
+                (None, None, 0)
+            };
+            let up_block = UpBlock {
+                layer_norm,
+                conv,
+                sub_blocks,
+            };
+            up_blocks.push(up_block)
+        }
 
         let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
         let clf_conv = candle_nn::conv2d(
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
index eea70a02..5dd03778 100644
--- a/candle-transformers/src/models/wuerstchen/prior.rs
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -85,10 +85,9 @@ impl WPrior {
     pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
         let x_in = xs;
         let mut xs = xs.apply(&self.projection)?;
-        // TODO: leaky relu
         let c_embed = c
             .apply(&self.cond_mapper_lin1)?
-            .relu()?
+            .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
             .apply(&self.cond_mapper_lin2)?;
         let r_embed = self.gen_r_embedding(r)?;
         for block in self.blocks.iter() {
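For context on how the pieces fit together: dropping the `std::fmt::Debug` bound and adding the blanket `impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T` lets any closure over a tensor be used wherever a `Module` is expected, which is exactly what `WPrior::forward` relies on when it routes its activation through `Tensor::apply`. The snippet below is not part of the patch, just a minimal usage sketch; it assumes the `candle-core` and `candle-nn` crates from this tree as dependencies, a CPU device, and a negative slope of 0.2.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[-2.0f32, -0.5, 0.0, 1.0, 3.0], &dev)?;

    // With the blanket impl, a closure `Fn(&Tensor) -> Result<Tensor>` satisfies
    // `Module`, so it can be passed straight to `Tensor::apply`, the same way
    // `WPrior::forward` applies its leaky ReLU.
    let ys = xs.apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?;

    // Leaky ReLU keeps non-negative entries and scales negative ones by the
    // slope (0.2 here), so the expected output is [-0.4, -0.1, 0.0, 1.0, 3.0].
    println!("{ys}");
    Ok(())
}
```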