This commit is contained in:
Anubhab Bandyopadhyay
2024-10-17 16:30:45 +05:30
committed by GitHub
parent a01aa89799
commit dcd83336b6
2 changed files with 278 additions and 3 deletions

View File

@ -1520,14 +1520,15 @@ impl Tensor {
/// # Arguments
///
/// * `self` - The input tensor.
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
/// but can have a different number of elements on the target dimension.
/// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
/// and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
/// * `dim` - the target dimension.
///
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
/// dimension `dim` by the values in `indexes`.
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "gather")?;
let self_dims = self.dims();
let indexes_dims = indexes.dims();
let mismatch = if indexes_dims.len() != self_dims.len() {
@ -1535,7 +1536,7 @@ impl Tensor {
} else {
let mut mismatch = false;
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
if i != dim && d1 != d2 {
if i != dim && d1 < d2 {
mismatch = true;
break;
}