mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Replication pad (#861)
* Add the embed mapper convolutions. * Add the replication pad layer. * Use the replication-pad op. * Tweak a todo.
This commit is contained in:
@ -213,3 +213,18 @@ pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
|
||||
.permute((0, 1, 3, 5, 2, 4))?
|
||||
.reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
|
||||
}
|
||||
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
|
||||
pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
|
||||
match pad {
|
||||
0 => Ok(xs.clone()),
|
||||
1 => {
|
||||
let (_b_size, _c, h, w) = xs.dims4()?;
|
||||
let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
|
||||
let xs = Tensor::cat(&[&first, xs, &last], 3)?;
|
||||
let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
|
||||
Tensor::cat(&[&first, &xs, &last], 2)
|
||||
}
|
||||
n => candle::bail!("replication-pad with a size of {n} is not supported"),
|
||||
}
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ impl Module for MixingResidualBlock {
|
||||
.apply(&self.norm1)?
|
||||
.permute((0, 3, 1, 2))?
|
||||
.affine(1. + mods[0] as f64, mods[1] as f64)?;
|
||||
// TODO: Add the ReplicationPad2d
|
||||
let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?;
|
||||
let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
|
||||
let x_temp = xs
|
||||
.permute((0, 2, 3, 1))?
|
||||
@ -88,10 +88,10 @@ impl PaellaVQ {
|
||||
}
|
||||
xs.apply(&self.down_blocks_conv)?
|
||||
.apply(&self.down_blocks_bn)
|
||||
// TODO: quantizer
|
||||
}
|
||||
|
||||
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: quantizer if we want to support `force_not_quantize=False`.
|
||||
let mut xs = xs.apply(&self.up_blocks_conv)?;
|
||||
for up_block in self.up_blocks.iter() {
|
||||
xs = xs.apply(&up_block.0)?;
|
||||
|
Reference in New Issue
Block a user