Simple pad support. (#336)

* Simple pad support.

* Fix the tensor indexing when padding.
This commit is contained in:
Laurent Mazare
2023-08-07 16:24:56 +02:00
committed by GitHub
parent e72ba0b9e7
commit f53a333ea9
4 changed files with 31 additions and 7 deletions

View File

@ -57,7 +57,7 @@ impl Timesteps {
Tensor::cat(&[&sin, &cos], D::Minus1)?
};
if self.num_channels % 2 == 1 {
crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None)
emb.pad_with_zeros(D::Minus2, 0, 1)
} else {
Ok(emb)
}

View File

@ -5,7 +5,7 @@ use crate::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use candle::{Result, Tensor};
use candle::{Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
@ -39,7 +39,9 @@ impl Downsample2D {
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
Some(conv) => {
if self.padding == 0 {
let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?;
let xs = xs
.pad_with_zeros(D::Minus1, 0, 1)?
.pad_with_zeros(D::Minus2, 0, 1)?;
conv.forward(&xs)
} else {
conv.forward(xs)

View File

@ -4,10 +4,6 @@ pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
todo!()
}
pub fn pad(_: &Tensor) -> Result<Tensor> {
todo!()
}
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
todo!()
}