ONNX: GatherElements, Xor (#2568)

This commit is contained in:
Anubhab Bandyopadhyay
2024-10-17 23:52:35 +05:30
committed by GitHub
parent dcd83336b6
commit 7c09215ef4
2 changed files with 582 additions and 0 deletions

View File

@ -670,6 +670,49 @@ fn simple_eval_(
}; };
values.insert(node.output[0].clone(), xs); values.insert(node.output[0].clone(), xs);
} }
// https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
// A Note to fellow lurkers:
// The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
// and examples is incorrect.
// Use `torch.gather` for the validating/ verifying against the proper behaviour
"GatherElements" => {
let data = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let rank = data.rank();
if rank != indices.rank() {
bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
}
let axis = {
let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = data.normalize_axis(axis_i64)?;
if axis >= rank {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
axis_i64,
rank - 1
)
}
axis
};
// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(indices)?
};
values.insert(node.output[0].clone(), data.gather(indices, axis)?);
}
"Shape" => { "Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?; let xs = get(&node.input[0])?;
@ -1891,6 +1934,16 @@ fn simple_eval_(
); );
} }
} }
// https://onnx.ai/onnx/operators/onnx__Xor.html
"Xor" => {
// Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True)
let a = get(&node.input[0])?.gt(0_u8)?;
let b = get(&node.input[1])?.gt(0_u8)?;
let out = a.broadcast_add(&b)?.eq(1_u8)?;
values.insert(node.output[0].clone(), out);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"), op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
} }
} }

View File

@ -1159,6 +1159,163 @@ fn test_gather_operation() -> Result<()> {
Ok(()) Ok(())
} }
// GatherElements
#[test]
fn test_gather_elements() -> Result<()> {
// all the tests below are verified against `torch.gather()`
// Rank 1 index
test(&[1.0, 2.0, 3.0, 4.0], &[3i64], 0, &[4.0])?;
// Rank 2 index
test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 1, &[[4.0]])?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_0
test(
&[[1., 2.], [3., 4.]],
&[[0i64, 0], [1, 0]],
1,
&[[1., 1.], [4., 3.]],
)?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_1
test(
&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
&[[1i64, 2, 0], [2, 0, 0]],
0,
&[[4., 8., 3.], [7., 2., 3.]],
)?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_negative_indices
test(
&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
&[[-1_i64, -2, 0], [-2, 0, 0]],
0,
&[[7., 5., 3.], [4., 2., 3.]],
)?;
test(
&[[1.0], [2.0], [3.0], [4.0]],
&[[3i64], [2]],
0,
&[[4.], [3.]],
)?;
// Rank 3
test(
&[
[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
[[13.0, 14.0], [15.0, 16.0]],
],
&[[[1i64]]],
0,
&[[[5.]]],
)?;
test(
&[
[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
[[13.0, 14.0], [15.0, 16.0]],
],
&[[[1i64]]],
1,
&[[[3.]]],
)?;
test(
&[
[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
[[13.0, 14.0], [15.0, 16.0]],
],
&[[[1i64], [0]]],
2,
&[[[2.], [3.]]],
)?;
// Error cases
// Invalid index
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 0, &[[1., 2., 3., 4.]]).is_err());
// Invalid axis/ dim
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 2, &[[1., 2., 3., 4.]]).is_err());
// Invalid rank
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[3i64], 0, &[[1.]]).is_err());
fn test(
data: impl NdArray,
indices: impl NdArray,
axis: i64,
expected: impl NdArray,
) -> Result<()> {
let att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
i: axis,
doc_string: "axis".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "GatherElements".to_string(),
domain: "".to_string(),
attribute: vec![att_axis],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Size" // "Size"
#[test] #[test]
fn test_size_operation() -> Result<()> { fn test_size_operation() -> Result<()> {
@ -5340,3 +5497,375 @@ fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
Ok(()) Ok(())
} }
// Xor
#[test]
fn test_xor() -> Result<()> {
// tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor xor
// 2d
test(
&[[0_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 1, 1]],
&[[1_u8, 1, 0, 0], [1, 0, 0, 1], [1, 1, 1, 0]],
&[[1_u8, 0, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1]],
)?;
// 3d
test(
&[
[
[0_u8, 1, 1, 1, 1],
[0, 1, 1, 0, 0],
[1, 1, 1, 1, 1],
[0, 0, 0, 0, 1],
],
[
[0, 0, 1, 1, 1],
[1, 0, 1, 1, 1],
[1, 1, 0, 0, 1],
[1, 0, 0, 1, 0],
],
[
[1, 0, 0, 1, 1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 1],
[1, 0, 0, 0, 1],
],
],
&[
[
[1_u8, 0, 0, 1, 1],
[0, 0, 1, 0, 1],
[1, 0, 0, 1, 0],
[0, 0, 0, 0, 0],
],
[
[1, 0, 0, 1, 1],
[1, 0, 1, 1, 1],
[0, 1, 0, 1, 1],
[1, 1, 1, 0, 0],
],
[
[0, 1, 1, 1, 0],
[1, 1, 0, 1, 0],
[0, 1, 1, 1, 0],
[1, 1, 0, 1, 0],
],
],
&[
[
[1_u8, 1, 1, 0, 0],
[0, 1, 0, 0, 1],
[0, 1, 1, 0, 1],
[0, 0, 0, 0, 1],
],
[
[1, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 1, 0],
[0, 1, 1, 1, 0],
],
[
[1, 1, 1, 0, 1],
[0, 0, 1, 1, 0],
[1, 0, 1, 1, 1],
[0, 1, 0, 1, 1],
],
],
)?;
// 4d
test(
&[
[
[[0_u8, 1, 1, 0], [1, 0, 0, 0], [1, 1, 0, 1]],
[[1, 1, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1]],
],
[
[[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 0]],
[[1, 0, 0, 1], [1, 0, 1, 1], [1, 1, 0, 1]],
],
],
&[
[
[[1_u8, 0, 1, 0], [0, 0, 1, 1], [1, 0, 1, 0]],
[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
],
[
[[1, 1, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
[[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1]],
],
],
&[
[
[[1_u8, 1, 0, 0], [1, 0, 1, 1], [0, 1, 1, 1]],
[[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 0]],
],
[
[[0, 0, 1, 0], [1, 0, 1, 1], [1, 0, 1, 0]],
[[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 1, 0]],
],
],
)?;
// tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor xor_broadcast
// 3d vs 1d
test(
// Shape (3, 4, 5)
&[
[
[0_u8, 0, 0, 0, 1],
[0, 1, 0, 1, 1],
[1, 0, 0, 1, 1],
[0, 0, 1, 0, 1],
],
[
[0, 1, 0, 1, 1],
[1, 1, 0, 0, 1],
[0, 1, 1, 1, 0],
[0, 0, 0, 0, 1],
],
[
[1, 1, 0, 1, 1],
[0, 0, 0, 1, 1],
[0, 1, 1, 0, 1],
[1, 1, 0, 1, 1],
],
],
// shape (5)
&[1_u8, 0, 0, 1, 1],
// shape (3, 4, 5)
&[
[
[1_u8, 0, 0, 1, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 1, 1, 0],
],
[
[1, 1, 0, 0, 0],
[0, 1, 0, 1, 0],
[1, 1, 1, 0, 1],
[1, 0, 0, 1, 0],
],
[
[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 1, 1, 1, 0],
[0, 1, 0, 0, 0],
],
],
)?;
// 3d vs 2d
test(
// Shape (3, 4, 5)
&[
[
[0_u8, 0, 0, 0, 1],
[0, 1, 0, 1, 1],
[1, 0, 0, 1, 1],
[0, 0, 1, 0, 1],
],
[
[0, 1, 0, 1, 1],
[1, 1, 0, 0, 1],
[0, 1, 1, 1, 0],
[0, 0, 0, 0, 1],
],
[
[1, 1, 0, 1, 1],
[0, 0, 0, 1, 1],
[0, 1, 1, 0, 1],
[1, 1, 0, 1, 1],
],
],
// shape (4, 5)
&[
[0_u8, 1, 0, 1, 0],
[0, 0, 1, 0, 0],
[1, 1, 0, 1, 1],
[1, 1, 0, 1, 0],
],
// shape (3, 4, 5)
&[
[
[0_u8, 1, 0, 1, 1],
[0, 1, 1, 1, 1],
[0, 1, 0, 0, 0],
[1, 1, 1, 1, 1],
],
[
[0, 0, 0, 0, 1],
[1, 1, 1, 0, 1],
[1, 0, 1, 0, 1],
[1, 1, 0, 1, 1],
],
[
[1, 0, 0, 0, 1],
[0, 0, 1, 1, 1],
[1, 0, 1, 1, 0],
[0, 0, 0, 0, 1],
],
],
)?;
// 4d vs 2d
test(
// Shape (2, 3, 3, 4)
&[
[
[[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]],
[[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]],
[[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]],
],
[
[[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]],
[[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]],
[[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]],
],
],
// shape (3, 4)
&[[0_u8, 0, 1, 1], [1, 1, 1, 1], [0, 1, 0, 1]],
// shape (2, 3, 3, 4)
&[
[
[[1_u8, 0, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]],
[[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 0]],
[[1, 0, 1, 1], [0, 0, 0, 1], [0, 1, 1, 0]],
],
[
[[0, 1, 1, 0], [0, 0, 1, 0], [1, 1, 1, 0]],
[[1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 0]],
[[1, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 0]],
],
],
)?;
// 4d vs 3d
test(
// Shape (2, 3, 3, 4)
&[
[
[[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]],
[[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]],
[[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]],
],
[
[[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]],
[[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]],
[[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]],
],
],
// shape (3, 3, 4)
&[
[[1_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 0, 0]],
[[0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]],
[[0, 1, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]],
],
// shape (2, 3, 3, 4)
&[
[
[[0_u8, 1, 0, 1], [1, 1, 1, 1], [0, 0, 0, 0]],
[[1, 0, 0, 1], [0, 1, 0, 0], [1, 1, 0, 0]],
[[1, 1, 1, 0], [0, 1, 0, 1], [1, 1, 1, 0]],
],
[
[[1, 0, 0, 1], [1, 1, 1, 0], [1, 1, 1, 1]],
[[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]],
[[1, 1, 1, 0], [0, 1, 1, 1], [1, 0, 0, 0]],
],
],
)?;
// 4d vs 4d
test(
// Shape (1, 4, 1, 2)
&[[[[1_u8, 0]], [[1, 0]], [[1, 0]], [[1, 1]]]],
// shape (2, 1, 4, 2)
&[
[[[0_u8, 0], [1, 1], [1, 1], [1, 1]]],
[[[0, 1], [1, 0], [0, 1], [0, 0]]],
],
// shape (2, 4, 4, 2)
&[
[
[[1_u8, 0], [0, 1], [0, 1], [0, 1]],
[[1, 0], [0, 1], [0, 1], [0, 1]],
[[1, 0], [0, 1], [0, 1], [0, 1]],
[[1, 1], [0, 0], [0, 0], [0, 0]],
],
[
[[1, 1], [0, 0], [1, 1], [1, 0]],
[[1, 1], [0, 0], [1, 1], [1, 0]],
[[1, 1], [0, 0], [1, 1], [1, 0]],
[[1, 0], [0, 1], [1, 0], [1, 1]],
],
],
)?;
fn test(input: impl NdArray, other: impl NdArray, expected: impl NdArray) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Xor".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let inputs: HashMap<String, Tensor> = HashMap::from([
(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?),
(INPUT_Y.to_string(), Tensor::new(other, &Device::Cpu)?),
]);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
0 => {
assert_eq!(z.to_vec0::<u8>()?, expected.to_vec0::<u8>()?)
}
1 => {
assert_eq!(z.to_vec1::<u8>()?, expected.to_vec1::<u8>()?)
}
2 => {
assert_eq!(z.to_vec2::<u8>()?, expected.to_vec2::<u8>()?)
}
3 => {
assert_eq!(z.to_vec3::<u8>()?, expected.to_vec3::<u8>()?)
}
4 => {
// Candle has no method equivallent to `to_vec4()`
// So, as a hack, we flatten it to a single dim vec to test the results
assert_eq!(
z.flatten_all()?.to_vec1::<u8>()?,
expected.flatten_all()?.to_vec1::<u8>()?
)
}
_ => unreachable!(),
};
Ok(())
}
Ok(())
}