mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Simple pad support. (#336)
* Simple pad support. * Fix the tensor indexing when padding.
This commit is contained in:
@ -1759,6 +1759,32 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||
if left == 0 && right == 0 {
|
||||
Ok(self.clone())
|
||||
} else if left == 0 {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = right;
|
||||
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||
Tensor::cat(&[self, &right], dim)
|
||||
} else if right == 0 {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = left;
|
||||
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||
Tensor::cat(&[&left, self], dim)
|
||||
} else {
|
||||
let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = left;
|
||||
let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||
dims[dim] = right;
|
||||
let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
|
||||
Tensor::cat(&[&left, self, &right], dim)
|
||||
}
|
||||
}
|
||||
|
||||
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
|
Reference in New Issue
Block a user