mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Simple pad support. (#336)
* Simple pad support. * Fix the tensor indexing when padding.
This commit is contained in:
@ -1759,6 +1759,32 @@ impl Tensor {
|
|||||||
Ok(from_storage(storage, shape, op, false))
|
Ok(from_storage(storage, shape, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||||
|
if left == 0 && right == 0 {
|
||||||
|
Ok(self.clone())
|
||||||
|
} else if left == 0 {
|
||||||
|
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||||
|
let mut dims = self.dims().to_vec();
|
||||||
|
dims[dim] = right;
|
||||||
|
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||||
|
Tensor::cat(&[self, &right], dim)
|
||||||
|
} else if right == 0 {
|
||||||
|
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||||
|
let mut dims = self.dims().to_vec();
|
||||||
|
dims[dim] = left;
|
||||||
|
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||||
|
Tensor::cat(&[&left, self], dim)
|
||||||
|
} else {
|
||||||
|
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||||
|
let mut dims = self.dims().to_vec();
|
||||||
|
dims[dim] = left;
|
||||||
|
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||||
|
dims[dim] = right;
|
||||||
|
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||||
|
Tensor::cat(&[&left, self, &right], dim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||||
self.storage.read().unwrap()
|
self.storage.read().unwrap()
|
||||||
}
|
}
|
||||||
|
@ -57,7 +57,7 @@ impl Timesteps {
|
|||||||
Tensor::cat(&[&sin, &cos], D::Minus1)?
|
Tensor::cat(&[&sin, &cos], D::Minus1)?
|
||||||
};
|
};
|
||||||
if self.num_channels % 2 == 1 {
|
if self.num_channels % 2 == 1 {
|
||||||
crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None)
|
emb.pad_with_zeros(D::Minus2, 0, 1)
|
||||||
} else {
|
} else {
|
||||||
Ok(emb)
|
Ok(emb)
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@ use crate::attention::{
|
|||||||
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
|
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
|
||||||
};
|
};
|
||||||
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
|
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
|
||||||
use candle::{Result, Tensor};
|
use candle::{Result, Tensor, D};
|
||||||
use candle_nn as nn;
|
use candle_nn as nn;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -39,7 +39,9 @@ impl Downsample2D {
|
|||||||
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
|
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
|
||||||
Some(conv) => {
|
Some(conv) => {
|
||||||
if self.padding == 0 {
|
if self.padding == 0 {
|
||||||
let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?;
|
let xs = xs
|
||||||
|
.pad_with_zeros(D::Minus1, 0, 1)?
|
||||||
|
.pad_with_zeros(D::Minus2, 0, 1)?;
|
||||||
conv.forward(&xs)
|
conv.forward(&xs)
|
||||||
} else {
|
} else {
|
||||||
conv.forward(xs)
|
conv.forward(xs)
|
||||||
|
@ -4,10 +4,6 @@ pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
|
|||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pad(_: &Tensor) -> Result<Tensor> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
|
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user