Simple pad support. (#336)

* Simple pad support.

* Fix the tensor indexing when padding.
This commit is contained in:
Laurent Mazare
2023-08-07 16:24:56 +02:00
committed by GitHub
parent e72ba0b9e7
commit f53a333ea9
4 changed files with 31 additions and 7 deletions

View File

@ -1759,6 +1759,32 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false)) Ok(from_storage(storage, shape, op, false))
} }
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
if left == 0 && right == 0 {
Ok(self.clone())
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = right;
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[self, &right], dim)
} else if right == 0 {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = left;
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[&left, self], dim)
} else {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = left;
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
dims[dim] = right;
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[&left, self, &right], dim)
}
}
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap() self.storage.read().unwrap()
} }

View File

@ -57,7 +57,7 @@ impl Timesteps {
Tensor::cat(&[&sin, &cos], D::Minus1)? Tensor::cat(&[&sin, &cos], D::Minus1)?
}; };
if self.num_channels % 2 == 1 { if self.num_channels % 2 == 1 {
crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None) emb.pad_with_zeros(D::Minus2, 0, 1)
} else { } else {
Ok(emb) Ok(emb)
} }

View File

@ -5,7 +5,7 @@ use crate::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
}; };
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use candle::{Result, Tensor}; use candle::{Result, Tensor, D};
use candle_nn as nn; use candle_nn as nn;
#[derive(Debug)] #[derive(Debug)]
@ -39,7 +39,9 @@ impl Downsample2D {
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None), None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
Some(conv) => { Some(conv) => {
if self.padding == 0 { if self.padding == 0 {
let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?; let xs = xs
.pad_with_zeros(D::Minus1, 0, 1)?
.pad_with_zeros(D::Minus2, 0, 1)?;
conv.forward(&xs) conv.forward(&xs)
} else { } else {
conv.forward(xs) conv.forward(xs)

View File

@ -4,10 +4,6 @@ pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
todo!() todo!()
} }
pub fn pad(_: &Tensor) -> Result<Tensor> {
todo!()
}
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> { pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
todo!() todo!()
} }