DiffNeXt/unet (#859)

* DiffNeXt/unet

* Start adding the vae.

* VAE residual block.

* VAE forward pass.

* Add pixel shuffling.

* Actually use pixel shuffling.
This commit is contained in:
Laurent Mazare
2023-09-15 11:14:02 +02:00
committed by GitHub
parent 81a36b8713
commit 2746f2c4be
5 changed files with 216 additions and 9 deletions

View File

@ -189,3 +189,27 @@ impl candle::CustomOp1 for SoftmaxLastDim {
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
xs.apply_op1_no_bwd(&SoftmaxLastDim)
}
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;
let out_c = c / upscale_factor / upscale_factor;
xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
.permute((0, 1, 4, 2, 5, 3))?
.reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
}
pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;
let out_c = c * downscale_factor * downscale_factor;
xs.reshape((
b_size,
c,
h / downscale_factor,
downscale_factor,
w / downscale_factor,
downscale_factor,
))?
.permute((0, 1, 3, 5, 2, 4))?
.reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}