mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
DiffNeXt/unet (#859)
* DiffNeXt/unet * Start adding the vae. * VAE residual block. * VAE forward pass. * Add pixel shuffling. * Actually use pixel shuffling.
This commit is contained in:
@ -189,3 +189,27 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply_op1_no_bwd(&SoftmaxLastDim)
|
||||
}
|
||||
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
|
||||
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
|
||||
let (b_size, c, h, w) = xs.dims4()?;
|
||||
let out_c = c / upscale_factor / upscale_factor;
|
||||
xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
|
||||
.permute((0, 1, 4, 2, 5, 3))?
|
||||
.reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
|
||||
}
|
||||
|
||||
pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
|
||||
let (b_size, c, h, w) = xs.dims4()?;
|
||||
let out_c = c * downscale_factor * downscale_factor;
|
||||
xs.reshape((
|
||||
b_size,
|
||||
c,
|
||||
h / downscale_factor,
|
||||
downscale_factor,
|
||||
w / downscale_factor,
|
||||
downscale_factor,
|
||||
))?
|
||||
.permute((0, 1, 3, 5, 2, 4))?
|
||||
.reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
|
||||
}
|
||||
|
Reference in New Issue
Block a user