Replication pad (#861)

* Add the embed mapper convolutions.

* Add the replication pad layer.

* Use the replication-pad op.

* Tweak a todo.

Author: Laurent Mazare (committed by GitHub)
Date: 2023-09-15 15:06:21 +02:00
parent 107d3d9530
commit 30be5b6660
2 changed files with 17 additions and 2 deletions

@@ -213,3 +213,18 @@ pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
        .permute((0, 1, 3, 5, 2, 4))?
        .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}
// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
    match pad {
        0 => Ok(xs.clone()),
        1 => {
            let (_b_size, _c, h, w) = xs.dims4()?;
            // Replicate the first and last columns to pad the width dimension.
            let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
            let xs = Tensor::cat(&[&first, xs, &last], 3)?;
            // Then replicate the first and last rows to pad the height dimension.
            let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
            Tensor::cat(&[&first, &xs, &last], 2)
        }
        n => candle::bail!("replication-pad with a size of {n} is not supported"),
    }
}
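
For reference, a minimal usage sketch of the new op (not part of the commit). The changed file is not named above, so the module path is an assumption; since sibling helpers such as pixel_unshuffle live in candle-nn's ops module, the sketch calls candle_nn::ops::replication_pad2d and assumes the candle and candle-nn crates are available under those names.

// Usage sketch only; the candle_nn::ops::replication_pad2d path is an assumption.
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // 1x1x2x2 input holding [[0, 1], [2, 3]].
    let xs = Tensor::arange(0f32, 4., &dev)?.reshape((1, 1, 2, 2))?;
    // pad = 1 replicates the border row/column on each side, so 2x2 becomes 4x4.
    let padded = candle_nn::ops::replication_pad2d(&xs, 1)?;
    assert_eq!(padded.dims4()?, (1, 1, 4, 4));
    // The first output row is [0, 0, 1, 1] and the last is [2, 2, 3, 3];
    // pad = 0 returns the input unchanged, and any other pad value errors out.
    Ok(())
}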