onnx: support negative index in Gather (#2440)

index_select does not support negative indexing, but
this change adds just enough workarounds in onnx to
allow evaluating silero-vad models (which make use of
negative indices).
This commit is contained in:
shua
2024-08-22 15:28:25 +02:00
committed by GitHub
parent a8288b7a72
commit 1e96b8b695

View File

@ -629,6 +629,18 @@ fn simple_eval_(
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0); let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?; let axis = xs.normalize_axis(axis)?;
// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(xs.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(&indices)?
};
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
// tensor directly, but candle does not support tensor indexing at the moment, so // tensor directly, but candle does not support tensor indexing at the moment, so
// some workarounds must be done. // some workarounds must be done.