Replication pad (#861)

* Add the embed mapper convolutions.

* Add the replication pad layer.

* Use the replication-pad op.

* Tweak a todo.
This commit is contained in:
Laurent Mazare
2023-09-15 15:06:21 +02:00
committed by GitHub
parent 107d3d9530
commit 30be5b6660
2 changed files with 17 additions and 2 deletions

View File

@ -49,7 +49,7 @@ impl Module for MixingResidualBlock {
.apply(&self.norm1)?
.permute((0, 3, 1, 2))?
.affine(1. + mods[0] as f64, mods[1] as f64)?;
// TODO: Add the ReplicationPad2d
let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?;
let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
let x_temp = xs
.permute((0, 2, 3, 1))?
@ -88,10 +88,10 @@ impl PaellaVQ {
}
xs.apply(&self.down_blocks_conv)?
.apply(&self.down_blocks_bn)
// TODO: quantizer
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: quantizer if we want to support `force_not_quantize=False`.
let mut xs = xs.apply(&self.up_blocks_conv)?;
for up_block in self.up_blocks.iter() {
xs = xs.apply(&up_block.0)?;