ONNX: GatherElements, Xor (#2568)

This commit is contained in:
Anubhab Bandyopadhyay
2024-10-17 23:52:35 +05:30
committed by GitHub
parent dcd83336b6
commit 7c09215ef4
2 changed files with 582 additions and 0 deletions

View File

@ -670,6 +670,49 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), xs);
}
// https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
// A Note to fellow lurkers:
// The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
// and examples is incorrect.
// Use `torch.gather` for the validating/ verifying against the proper behaviour
"GatherElements" => {
let data = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let rank = data.rank();
if rank != indices.rank() {
bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
}
let axis = {
let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = data.normalize_axis(axis_i64)?;
if axis >= rank {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
axis_i64,
rank - 1
)
}
axis
};
// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(indices)?
};
values.insert(node.output[0].clone(), data.gather(indices, axis)?);
}
"Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?;
@ -1891,6 +1934,16 @@ fn simple_eval_(
);
}
}
// https://onnx.ai/onnx/operators/onnx__Xor.html
"Xor" => {
// Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True)
let a = get(&node.input[0])?.gt(0_u8)?;
let b = get(&node.input[1])?.gt(0_u8)?;
let out = a.broadcast_add(&b)?.eq(1_u8)?;
values.insert(node.output[0].clone(), out);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}