diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index 33ca192e..74e1836c 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -101,10 +101,39 @@ impl WDiffNeXt { const BLOCKS: [usize; 4] = [4, 4, 14, 4]; const NHEAD: [usize; 4] = [0, 10, 20, 20]; const INJECT_EFFNET: [bool; 4] = [false, true, true, true]; + const EFFNET_EMBD: usize = 16; let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?; - // TODO: populate effnet_mappers - let effnet_mappers = vec![]; + let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len()); + let vb_e = vb.pp("effnet_mappers"); + for (i, &inject) in INJECT_EFFNET.iter().enumerate() { + let c = if inject { + Some(candle_nn::conv2d( + EFFNET_EMBD, + c_cond, + 1, + Default::default(), + vb_e.pp(i), + )?) + } else { + None + }; + effnet_mappers.push(c) + } + for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() { + let c = if inject { + Some(candle_nn::conv2d( + EFFNET_EMBD, + c_cond, + 1, + Default::default(), + vb_e.pp(i + INJECT_EFFNET.len()), + )?) + } else { + None + }; + effnet_mappers.push(c) + } let cfg = candle_nn::layer_norm::LayerNormConfig { ..Default::default() };