Simple pad support. (#336)

* Simple pad support.

* Fix the tensor indexing when padding.
This commit is contained in:
Laurent Mazare
2023-08-07 16:24:56 +02:00
committed by GitHub
parent e72ba0b9e7
commit f53a333ea9
4 changed files with 31 additions and 7 deletions

View File

@ -1759,6 +1759,32 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
if left == 0 && right == 0 {
Ok(self.clone())
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = right;
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[self, &right], dim)
} else if right == 0 {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = left;
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[&left, self], dim)
} else {
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
let mut dims = self.dims().to_vec();
dims[dim] = left;
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
dims[dim] = right;
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
Tensor::cat(&[&left, self, &right], dim)
}
}
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}