mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Testcases (#2567)
This commit is contained in:

committed by
GitHub

parent
a01aa89799
commit
dcd83336b6
@ -1520,14 +1520,15 @@ impl Tensor {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `self` - The input tensor.
|
||||
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
|
||||
/// but can have a different number of elements on the target dimension.
|
||||
/// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
|
||||
/// and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
|
||||
/// * `dim` - the target dimension.
|
||||
///
|
||||
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
|
||||
/// dimension `dim` by the values in `indexes`.
|
||||
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "gather")?;
|
||||
|
||||
let self_dims = self.dims();
|
||||
let indexes_dims = indexes.dims();
|
||||
let mismatch = if indexes_dims.len() != self_dims.len() {
|
||||
@ -1535,7 +1536,7 @@ impl Tensor {
|
||||
} else {
|
||||
let mut mismatch = false;
|
||||
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
|
||||
if i != dim && d1 != d2 {
|
||||
if i != dim && d1 < d2 {
|
||||
mismatch = true;
|
||||
break;
|
||||
}
|
||||
|
Reference in New Issue
Block a user