mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the embed mapper convolutions. (#860)
This commit is contained in:
@ -101,10 +101,39 @@ impl WDiffNeXt {
|
||||
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
|
||||
const NHEAD: [usize; 4] = [0, 10, 20, 20];
|
||||
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
|
||||
const EFFNET_EMBD: usize = 16;
|
||||
|
||||
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
|
||||
// TODO: populate effnet_mappers
|
||||
let effnet_mappers = vec![];
|
||||
let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
|
||||
let vb_e = vb.pp("effnet_mappers");
|
||||
for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
|
||||
let c = if inject {
|
||||
Some(candle_nn::conv2d(
|
||||
EFFNET_EMBD,
|
||||
c_cond,
|
||||
1,
|
||||
Default::default(),
|
||||
vb_e.pp(i),
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
effnet_mappers.push(c)
|
||||
}
|
||||
for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
|
||||
let c = if inject {
|
||||
Some(candle_nn::conv2d(
|
||||
EFFNET_EMBD,
|
||||
c_cond,
|
||||
1,
|
||||
Default::default(),
|
||||
vb_e.pp(i + INJECT_EFFNET.len()),
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
effnet_mappers.push(c)
|
||||
}
|
||||
let cfg = candle_nn::layer_norm::LayerNormConfig {
|
||||
..Default::default()
|
||||
};
|
||||
|
Reference in New Issue
Block a user