mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add the upblocks. (#853)
This commit is contained in:
@ -161,8 +161,57 @@ impl WDiffNeXt {
|
||||
down_blocks.push(down_block)
|
||||
}
|
||||
|
||||
// TODO: populate.
|
||||
let up_blocks = Vec::with_capacity(C_HIDDEN.len());
|
||||
let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
|
||||
for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
|
||||
let vb = vb.pp("up_blocks").pp(i);
|
||||
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
|
||||
let mut layer_i = 0;
|
||||
for j in 0..BLOCKS[i] {
|
||||
let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
|
||||
let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
|
||||
c_hidden + c_skip
|
||||
} else {
|
||||
c_skip
|
||||
};
|
||||
let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
|
||||
layer_i += 1;
|
||||
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
|
||||
layer_i += 1;
|
||||
let attn_block = if j == 0 {
|
||||
None
|
||||
} else {
|
||||
let attn_block =
|
||||
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
|
||||
layer_i += 1;
|
||||
Some(attn_block)
|
||||
};
|
||||
let sub_block = SubBlock {
|
||||
res_block,
|
||||
ts_block,
|
||||
attn_block,
|
||||
};
|
||||
sub_blocks.push(sub_block)
|
||||
}
|
||||
let (layer_norm, conv, start_layer_i) = if i > 0 {
|
||||
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
|
||||
layer_i += 1;
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
|
||||
layer_i += 1;
|
||||
(Some(layer_norm), Some(conv), 2)
|
||||
} else {
|
||||
(None, None, 0)
|
||||
};
|
||||
let up_block = UpBlock {
|
||||
layer_norm,
|
||||
conv,
|
||||
sub_blocks,
|
||||
};
|
||||
up_blocks.push(up_block)
|
||||
}
|
||||
|
||||
let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
|
||||
let clf_conv = candle_nn::conv2d(
|
||||
|
@ -85,10 +85,9 @@ impl WPrior {
|
||||
pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
let x_in = xs;
|
||||
let mut xs = xs.apply(&self.projection)?;
|
||||
// TODO: leaky relu
|
||||
let c_embed = c
|
||||
.apply(&self.cond_mapper_lin1)?
|
||||
.relu()?
|
||||
.apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
|
||||
.apply(&self.cond_mapper_lin2)?;
|
||||
let r_embed = self.gen_r_embedding(r)?;
|
||||
for block in self.blocks.iter() {
|
||||
|
Reference in New Issue
Block a user