diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs index 6da7362c..4a69cca0 100644 --- a/candle-transformers/src/models/wuerstchen/paella_vq.rs +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -106,8 +106,7 @@ impl PaellaVQ { stride: 2, ..Default::default() }; - let block = - candle_nn::conv2d_no_bias(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?; + let block = candle_nn::conv2d(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?; d_idx += 1; Some(block) } else { @@ -130,7 +129,7 @@ impl PaellaVQ { let mut up_blocks = Vec::new(); let vb_u = vb.pp("up_blocks"); let mut u_idx = 0; - let up_blocks_conv = candle_nn::conv2d_no_bias( + let up_blocks_conv = candle_nn::conv2d( LATENT_CHANNELS, C_LEVELS[1], 1, @@ -152,7 +151,7 @@ impl PaellaVQ { stride: 2, ..Default::default() }; - let block = candle_nn::conv_transpose2d_no_bias( + let block = candle_nn::conv_transpose2d( c_level, C_LEVELS[C_LEVELS.len() - i - 2], 4,