Add the embed mapper convolutions. (#860)

This commit is contained in:
Laurent Mazare
2023-09-15 11:38:38 +02:00
committed by GitHub
parent 2746f2c4be
commit 107d3d9530

View File

@ -101,10 +101,39 @@ impl WDiffNeXt {
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
const NHEAD: [usize; 4] = [0, 10, 20, 20];
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
const EFFNET_EMBD: usize = 16;
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
// TODO: populate effnet_mappers
let effnet_mappers = vec![];
let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
let vb_e = vb.pp("effnet_mappers");
for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
let c = if inject {
Some(candle_nn::conv2d(
EFFNET_EMBD,
c_cond,
1,
Default::default(),
vb_e.pp(i),
)?)
} else {
None
};
effnet_mappers.push(c)
}
for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
let c = if inject {
Some(candle_nn::conv2d(
EFFNET_EMBD,
c_cond,
1,
Default::default(),
vb_e.pp(i + INJECT_EFFNET.len()),
)?)
} else {
None
};
effnet_mappers.push(c)
}
let cfg = candle_nn::layer_norm::LayerNormConfig {
..Default::default()
};