mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add the embed mapper convolutions. (#860)
This commit is contained in:
@ -101,10 +101,39 @@ impl WDiffNeXt {
|
|||||||
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
|
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
|
||||||
const NHEAD: [usize; 4] = [0, 10, 20, 20];
|
const NHEAD: [usize; 4] = [0, 10, 20, 20];
|
||||||
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
|
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
|
||||||
|
const EFFNET_EMBD: usize = 16;
|
||||||
|
|
||||||
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
|
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
|
||||||
// TODO: populate effnet_mappers
|
let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
|
||||||
let effnet_mappers = vec![];
|
let vb_e = vb.pp("effnet_mappers");
|
||||||
|
for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
|
||||||
|
let c = if inject {
|
||||||
|
Some(candle_nn::conv2d(
|
||||||
|
EFFNET_EMBD,
|
||||||
|
c_cond,
|
||||||
|
1,
|
||||||
|
Default::default(),
|
||||||
|
vb_e.pp(i),
|
||||||
|
)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
effnet_mappers.push(c)
|
||||||
|
}
|
||||||
|
for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
|
||||||
|
let c = if inject {
|
||||||
|
Some(candle_nn::conv2d(
|
||||||
|
EFFNET_EMBD,
|
||||||
|
c_cond,
|
||||||
|
1,
|
||||||
|
Default::default(),
|
||||||
|
vb_e.pp(i + INJECT_EFFNET.len()),
|
||||||
|
)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
effnet_mappers.push(c)
|
||||||
|
}
|
||||||
let cfg = candle_nn::layer_norm::LayerNormConfig {
|
let cfg = candle_nn::layer_norm::LayerNormConfig {
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user