Add some missing biases. (#908)

This commit is contained in:
Laurent Mazare
2023-09-20 10:14:51 +01:00
committed by GitHub
parent c0b49d5a50
commit f685b2231c

View File

@ -106,8 +106,7 @@ impl PaellaVQ {
stride: 2, stride: 2,
..Default::default() ..Default::default()
}; };
let block = let block = candle_nn::conv2d(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?;
candle_nn::conv2d_no_bias(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?;
d_idx += 1; d_idx += 1;
Some(block) Some(block)
} else { } else {
@ -130,7 +129,7 @@ impl PaellaVQ {
let mut up_blocks = Vec::new(); let mut up_blocks = Vec::new();
let vb_u = vb.pp("up_blocks"); let vb_u = vb.pp("up_blocks");
let mut u_idx = 0; let mut u_idx = 0;
let up_blocks_conv = candle_nn::conv2d_no_bias( let up_blocks_conv = candle_nn::conv2d(
LATENT_CHANNELS, LATENT_CHANNELS,
C_LEVELS[1], C_LEVELS[1],
1, 1,
@ -152,7 +151,7 @@ impl PaellaVQ {
stride: 2, stride: 2,
..Default::default() ..Default::default()
}; };
let block = candle_nn::conv_transpose2d_no_bias( let block = candle_nn::conv_transpose2d(
c_level, c_level,
C_LEVELS[C_LEVELS.len() - i - 2], C_LEVELS[C_LEVELS.len() - i - 2],
4, 4,