diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index fca51ef7..5b66a743 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -629,6 +629,18 @@ fn simple_eval_( let axis = get_attr_opt::(node, "axis")?.copied().unwrap_or(0); let axis = xs.normalize_axis(axis)?; + // index_select does not support negative indices, so normalize them + // to positive indices. + let indices = &{ + let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?; + let max = Tensor::new(xs.dims()[axis] as i64, indices.device())? + .to_dtype(indices.dtype())?; + let mask = indices.lt(&zeros)?; + mask.to_dtype(indices.dtype())? + .broadcast_mul(&max)? + .add(&indices)? + }; + // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices // tensor directly, but candle does not support tensor indexing at the moment, so // some workarounds must be done.