diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 0f028dd8..da47d180 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2135,6 +2135,46 @@ impl Tensor { } } + /// Pad the input tensor using same values along dimension `dim`. This adds `left` elements before the + /// input tensor values and `right` elements after. + pub fn pad_with_same(&self, dim: D, left: usize, right: usize) -> Result { + if left == 0 && right == 0 { + Ok(self.clone()) + } else if self.elem_count() == 0 { + crate::bail!("cannot use pad_with_same on an empty tensor") + } else if left == 0 { + let dim = dim.to_index(self.shape(), "pad_with_same")?; + let r = self.narrow(dim, self.dim(dim)? - 1, 1)?; + let mut v = vec![self]; + for _ in 0..right { + v.push(&r) + } + Tensor::cat(&v, dim) + } else if right == 0 { + let dim = dim.to_index(self.shape(), "pad_with_same")?; + let l = self.narrow(dim, 0, 1)?; + let mut v = vec![]; + for _ in 0..left { + v.push(&l) + } + v.push(self); + Tensor::cat(&v, dim) + } else { + let dim = dim.to_index(self.shape(), "pad_with_same")?; + let l = self.narrow(dim, 0, 1)?; + let r = self.narrow(dim, self.dim(dim)? - 1, 1)?; + let mut v = vec![]; + for _ in 0..left { + v.push(&l) + } + v.push(self); + for _ in 0..right { + v.push(&r) + } + Tensor::cat(&v, dim) + } + } + /// Run the `forward` method of `m` on `self`. pub fn apply(&self, m: &M) -> Result { m.forward(self) diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index c938ffea..ae1bd058 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1073,3 +1073,19 @@ fn randn_hasneg() -> Result<()> { } Ok(()) } + +#[test] +fn pad_with_same() -> Result<()> { + let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?; + let t0 = t.pad_with_same(0, 1, 2)?; + assert_eq!( + t0.to_vec2::()?, + [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]] + ); + let t1 = t.pad_with_same(1, 1, 2)?; + assert_eq!( + t1.to_vec2::()?, + [[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]] + ); + Ok(()) +}