From 30be5b6660ca86f8ddd2cca88890cf4e40e45e12 Mon Sep 17 00:00:00 2001 From: Laurent Mazare <laurent.mazare@gmail.com> Date: Fri, 15 Sep 2023 15:06:21 +0200 Subject: [PATCH] Replication pad (#861) * Add the embed mapper convolutions. * Add the replication pad layer. * Use the replication-pad op. * Tweak a todo. --- candle-nn/src/ops.rs | 15 +++++++++++++++ .../src/models/wuerstchen/paella_vq.rs | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 16b2e924..1256a076 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -213,3 +213,18 @@ pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> { .permute((0, 1, 3, 5, 2, 4))? .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor)) } + +// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html +pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> { + match pad { + 0 => Ok(xs.clone()), + 1 => { + let (_b_size, _c, h, w) = xs.dims4()?; + let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?); + let xs = Tensor::cat(&[&first, xs, &last], 3)?; + let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?); + Tensor::cat(&[&first, &xs, &last], 2) + } + n => candle::bail!("replication-pad with a size of {n} is not supported"), + } +} diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs index 6301b7a1..6589a07d 100644 --- a/candle-transformers/src/models/wuerstchen/paella_vq.rs +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -49,7 +49,7 @@ impl Module for MixingResidualBlock { .apply(&self.norm1)? .permute((0, 3, 1, 2))? .affine(1. + mods[0] as f64, mods[1] as f64)?; - // TODO: Add the ReplicationPad2d + let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?; let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?; let x_temp = xs .permute((0, 2, 3, 1))? 
@@ -88,10 +88,10 @@ impl PaellaVQ { } xs.apply(&self.down_blocks_conv)? .apply(&self.down_blocks_bn) - // TODO: quantizer } pub fn decode(&self, xs: &Tensor) -> Result<Tensor> { + // TODO: quantizer if we want to support `force_not_quantize=False`. let mut xs = xs.apply(&self.up_blocks_conv)?; for up_block in self.up_blocks.iter() { xs = xs.apply(&up_block.0)?;