diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ce5858fa..87323a84 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2503,6 +2503,64 @@ impl Tensor { t.transpose(dim, last) } } + + /// Returns a copy of `self` where the values within `ranges` have been replaced with the + /// content of `src`. + pub fn slice_assign>( + &self, + ranges: &[D], + src: &Tensor, + ) -> Result { + let src_dims = src.dims(); + let self_dims = self.dims(); + if self_dims.len() != src_dims.len() { + crate::bail!( + "slice-assign requires input with the same rank {} <> {}", + self_dims.len(), + src_dims.len() + ) + } + if self_dims.len() != ranges.len() { + crate::bail!( + "slice-assign requires input with the same rank as there are ranges {} <> {}", + self_dims.len(), + ranges.len() + ) + } + let mut src = src.clone(); + let mut mask = Self::ones(src.shape(), DType::U8, src.device())?; + for (i, range) in ranges.iter().enumerate() { + let start_included = match range.start_bound() { + std::ops::Bound::Unbounded => 0, + std::ops::Bound::Included(v) => *v, + std::ops::Bound::Excluded(v) => *v + 1, + }; + let end_excluded = match range.end_bound() { + std::ops::Bound::Unbounded => self_dims[i], + std::ops::Bound::Included(v) => *v + 1, + std::ops::Bound::Excluded(v) => *v, + }; + if end_excluded <= start_included { + crate::bail!( + "slice-assign: empty range for dim {i}, {start_included} {end_excluded}" + ) + } + if self_dims[i] < end_excluded { + crate::bail!( + "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", + self_dims[i] + ) + } + if end_excluded - start_included != src_dims[i] { + crate::bail!( + "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] + ) + } + src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?; + mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)? + } + mask.where_cond(/* on_true= */ &src, /* on_false= */ self) + } } macro_rules! bin_trait { diff --git a/candle-core/tests/indexing_tests.rs b/candle-core/tests/indexing_tests.rs index 9c88f319..047205a3 100644 --- a/candle-core/tests/indexing_tests.rs +++ b/candle-core/tests/indexing_tests.rs @@ -91,3 +91,32 @@ fn index_3d() -> Result<()> { assert_eq!(tensor.i((1, .., 3))?.to_vec1::()?, &[15, 19, 23]); Ok(()) } + +#[test] +fn slice_assign() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[1..4, 3..5], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 3, 4], + [5, 6, 7, 0, 1], + [10, 11, 12, 2, 3], + [15, 16, 17, 4, 5] + ] + ); + let out = tensor.slice_assign(&[0..3, 0..2], &src)?; + assert_eq!( + out.to_vec2::()?, + &[ + [0, 1, 2, 3, 4], + [2, 3, 7, 8, 9], + [4, 5, 12, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +}