mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
ONNX: GatherElements, Xor (#2568)
This commit is contained in:

committed by
GitHub

parent
dcd83336b6
commit
7c09215ef4
@ -670,6 +670,49 @@ fn simple_eval_(
|
||||
};
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
|
||||
// A Note to fellow lurkers:
|
||||
// The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
|
||||
// and examples is incorrect.
|
||||
// Use `torch.gather` for the validating/ verifying against the proper behaviour
|
||||
"GatherElements" => {
|
||||
let data = get(&node.input[0])?;
|
||||
let indices = get(&node.input[1])?;
|
||||
|
||||
let rank = data.rank();
|
||||
if rank != indices.rank() {
|
||||
bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
|
||||
}
|
||||
|
||||
let axis = {
|
||||
let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||
let axis = data.normalize_axis(axis_i64)?;
|
||||
|
||||
if axis >= rank {
|
||||
bail!(
|
||||
"axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
|
||||
axis_i64,
|
||||
rank - 1
|
||||
)
|
||||
}
|
||||
|
||||
axis
|
||||
};
|
||||
|
||||
// index_select does not support negative indices, so normalize them
|
||||
// to positive indices.
|
||||
let indices = &{
|
||||
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
|
||||
let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
|
||||
.to_dtype(indices.dtype())?;
|
||||
let mask = indices.lt(&zeros)?;
|
||||
mask.to_dtype(indices.dtype())?
|
||||
.broadcast_mul(&max)?
|
||||
.add(indices)?
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), data.gather(indices, axis)?);
|
||||
}
|
||||
"Shape" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
|
||||
let xs = get(&node.input[0])?;
|
||||
@ -1891,6 +1934,16 @@ fn simple_eval_(
|
||||
);
|
||||
}
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__Xor.html
|
||||
"Xor" => {
|
||||
// Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True)
|
||||
let a = get(&node.input[0])?.gt(0_u8)?;
|
||||
let b = get(&node.input[1])?.gt(0_u8)?;
|
||||
|
||||
let out = a.broadcast_add(&b)?.eq(1_u8)?;
|
||||
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
@ -1159,6 +1159,163 @@ fn test_gather_operation() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// GatherElements
|
||||
#[test]
|
||||
fn test_gather_elements() -> Result<()> {
|
||||
// all the tests below are verified against `torch.gather()`
|
||||
|
||||
// Rank 1 index
|
||||
test(&[1.0, 2.0, 3.0, 4.0], &[3i64], 0, &[4.0])?;
|
||||
|
||||
// Rank 2 index
|
||||
test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 1, &[[4.0]])?;
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_0
|
||||
test(
|
||||
&[[1., 2.], [3., 4.]],
|
||||
&[[0i64, 0], [1, 0]],
|
||||
1,
|
||||
&[[1., 1.], [4., 3.]],
|
||||
)?;
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_1
|
||||
test(
|
||||
&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
|
||||
&[[1i64, 2, 0], [2, 0, 0]],
|
||||
0,
|
||||
&[[4., 8., 3.], [7., 2., 3.]],
|
||||
)?;
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_negative_indices
|
||||
test(
|
||||
&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
|
||||
&[[-1_i64, -2, 0], [-2, 0, 0]],
|
||||
0,
|
||||
&[[7., 5., 3.], [4., 2., 3.]],
|
||||
)?;
|
||||
test(
|
||||
&[[1.0], [2.0], [3.0], [4.0]],
|
||||
&[[3i64], [2]],
|
||||
0,
|
||||
&[[4.], [3.]],
|
||||
)?;
|
||||
|
||||
// Rank 3
|
||||
test(
|
||||
&[
|
||||
[[1.0, 2.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
[[13.0, 14.0], [15.0, 16.0]],
|
||||
],
|
||||
&[[[1i64]]],
|
||||
0,
|
||||
&[[[5.]]],
|
||||
)?;
|
||||
|
||||
test(
|
||||
&[
|
||||
[[1.0, 2.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
[[13.0, 14.0], [15.0, 16.0]],
|
||||
],
|
||||
&[[[1i64]]],
|
||||
1,
|
||||
&[[[3.]]],
|
||||
)?;
|
||||
|
||||
test(
|
||||
&[
|
||||
[[1.0, 2.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
[[13.0, 14.0], [15.0, 16.0]],
|
||||
],
|
||||
&[[[1i64], [0]]],
|
||||
2,
|
||||
&[[[2.], [3.]]],
|
||||
)?;
|
||||
|
||||
// Error cases
|
||||
// Invalid index
|
||||
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 0, &[[1., 2., 3., 4.]]).is_err());
|
||||
// Invalid axis/ dim
|
||||
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 2, &[[1., 2., 3., 4.]]).is_err());
|
||||
// Invalid rank
|
||||
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[3i64], 0, &[[1.]]).is_err());
|
||||
|
||||
fn test(
|
||||
data: impl NdArray,
|
||||
indices: impl NdArray,
|
||||
axis: i64,
|
||||
expected: impl NdArray,
|
||||
) -> Result<()> {
|
||||
let att_axis = AttributeProto {
|
||||
name: "axis".to_string(),
|
||||
ref_attr_name: "axis".to_string(),
|
||||
i: axis,
|
||||
doc_string: "axis".to_string(),
|
||||
r#type: 2,
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "GatherElements".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_axis],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||
match expected.dims().len() {
|
||||
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Size"
|
||||
#[test]
|
||||
fn test_size_operation() -> Result<()> {
|
||||
@ -5340,3 +5497,375 @@ fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Xor
|
||||
#[test]
|
||||
fn test_xor() -> Result<()> {
|
||||
// tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor xor
|
||||
|
||||
// 2d
|
||||
test(
|
||||
&[[0_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 1, 1]],
|
||||
&[[1_u8, 1, 0, 0], [1, 0, 0, 1], [1, 1, 1, 0]],
|
||||
&[[1_u8, 0, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1]],
|
||||
)?;
|
||||
|
||||
// 3d
|
||||
test(
|
||||
&[
|
||||
[
|
||||
[0_u8, 1, 1, 1, 1],
|
||||
[0, 1, 1, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
],
|
||||
[
|
||||
[0, 0, 1, 1, 1],
|
||||
[1, 0, 1, 1, 1],
|
||||
[1, 1, 0, 0, 1],
|
||||
[1, 0, 0, 1, 0],
|
||||
],
|
||||
[
|
||||
[1, 0, 0, 1, 1],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 1],
|
||||
[1, 0, 0, 0, 1],
|
||||
],
|
||||
],
|
||||
&[
|
||||
[
|
||||
[1_u8, 0, 0, 1, 1],
|
||||
[0, 0, 1, 0, 1],
|
||||
[1, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
],
|
||||
[
|
||||
[1, 0, 0, 1, 1],
|
||||
[1, 0, 1, 1, 1],
|
||||
[0, 1, 0, 1, 1],
|
||||
[1, 1, 1, 0, 0],
|
||||
],
|
||||
[
|
||||
[0, 1, 1, 1, 0],
|
||||
[1, 1, 0, 1, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[1, 1, 0, 1, 0],
|
||||
],
|
||||
],
|
||||
&[
|
||||
[
|
||||
[1_u8, 1, 1, 0, 0],
|
||||
[0, 1, 0, 0, 1],
|
||||
[0, 1, 1, 0, 1],
|
||||
[0, 0, 0, 0, 1],
|
||||
],
|
||||
[
|
||||
[1, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 0, 0, 1, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
],
|
||||
[
|
||||
[1, 1, 1, 0, 1],
|
||||
[0, 0, 1, 1, 0],
|
||||
[1, 0, 1, 1, 1],
|
||||
[0, 1, 0, 1, 1],
|
||||
],
|
||||
],
|
||||
)?;
|
||||
|
||||
// 4d
|
||||
test(
|
||||
&[
|
||||
[
|
||||
[[0_u8, 1, 1, 0], [1, 0, 0, 0], [1, 1, 0, 1]],
|
||||
[[1, 1, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1]],
|
||||
],
|
||||
[
|
||||
[[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 0]],
|
||||
[[1, 0, 0, 1], [1, 0, 1, 1], [1, 1, 0, 1]],
|
||||
],
|
||||
],
|
||||
&[
|
||||
[
|
||||
[[1_u8, 0, 1, 0], [0, 0, 1, 1], [1, 0, 1, 0]],
|
||||
[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
|
||||
],
|
||||
[
|
||||
[[1, 1, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
|
||||
[[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1]],
|
||||
],
|
||||
],
|
||||
&[
|
||||
[
|
||||
[[1_u8, 1, 0, 0], [1, 0, 1, 1], [0, 1, 1, 1]],
|
||||
[[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 0]],
|
||||
],
|
||||
[
|
||||
[[0, 0, 1, 0], [1, 0, 1, 1], [1, 0, 1, 0]],
|
||||
[[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 1, 0]],
|
||||
],
|
||||
],
|
||||
)?;
|
||||
|
||||
// tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor xor_broadcast
|
||||
// 3d vs 1d
|
||||
test(
|
||||
// Shape (3, 4, 5)
|
||||
&[
|
||||
[
|
||||
[0_u8, 0, 0, 0, 1],
|
||||
[0, 1, 0, 1, 1],
|
||||
[1, 0, 0, 1, 1],
|
||||
[0, 0, 1, 0, 1],
|
||||
],
|
||||
[
|
||||
[0, 1, 0, 1, 1],
|
||||
[1, 1, 0, 0, 1],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 0, 1],
|
||||
],
|
||||
[
|
||||
[1, 1, 0, 1, 1],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 1, 1, 0, 1],
|
||||
[1, 1, 0, 1, 1],
|
||||
],
|
||||
],
|
||||
// shape (5)
|
||||
&[1_u8, 0, 0, 1, 1],
|
||||
// shape (3, 4, 5)
|
||||
&[
|
||||
[
|
||||
[1_u8, 0, 0, 1, 0],
|
||||
[1, 1, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 0, 1, 1, 0],
|
||||
],
|
||||
[
|
||||
[1, 1, 0, 0, 0],
|
||||
[0, 1, 0, 1, 0],
|
||||
[1, 1, 1, 0, 1],
|
||||
[1, 0, 0, 1, 0],
|
||||
],
|
||||
[
|
||||
[0, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0],
|
||||
[0, 1, 0, 0, 0],
|
||||
],
|
||||
],
|
||||
)?;
|
||||
|
||||
// 3d vs 2d
|
||||
test(
|
||||
// Shape (3, 4, 5)
|
||||
&[
|
||||
[
|
||||
[0_u8, 0, 0, 0, 1],
|
||||
[0, 1, 0, 1, 1],
|
||||
[1, 0, 0, 1, 1],
|
||||
[0, 0, 1, 0, 1],
|
||||
],
|
||||
[
|
||||
[0, 1, 0, 1, 1],
|
||||
[1, 1, 0, 0, 1],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 0, 1],
|
||||
],
|
||||
[
|
||||
[1, 1, 0, 1, 1],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 1, 1, 0, 1],
|
||||
[1, 1, 0, 1, 1],
|
||||
],
|
||||
],
|
||||
// shape (4, 5)
|
||||
&[
|
||||
[0_u8, 1, 0, 1, 0],
|
||||
[0, 0, 1, 0, 0],
|
||||
[1, 1, 0, 1, 1],
|
||||
[1, 1, 0, 1, 0],
|
||||
],
|
||||
// shape (3, 4, 5)
|
||||
&[
|
||||
[
|
||||
[0_u8, 1, 0, 1, 1],
|
||||
[0, 1, 1, 1, 1],
|
||||
[0, 1, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1],
|
||||
],
|
||||
[
|
||||
[0, 0, 0, 0, 1],
|
||||
[1, 1, 1, 0, 1],
|
||||
[1, 0, 1, 0, 1],
|
||||
[1, 1, 0, 1, 1],
|
||||
],
|
||||
[
|
||||
[1, 0, 0, 0, 1],
|
||||
[0, 0, 1, 1, 1],
|
||||
[1, 0, 1, 1, 0],
|
||||
[0, 0, 0, 0, 1],
|
||||
],
|
||||
],
|
||||
)?;
|
||||
|
||||
// 4d vs 2d
|
||||
test(
|
||||
// Shape (2, 3, 3, 4)
|
||||
&[
|
||||
[
|
||||
[[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]],
|
||||
[[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]],
|
||||
[[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]],
|
||||
],
|
||||
[
|
||||
[[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]],
|
||||
[[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]],
|
||||
[[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]],
|
||||
],
|
||||
],
|
||||
// shape (3, 4)
|
||||
&[[0_u8, 0, 1, 1], [1, 1, 1, 1], [0, 1, 0, 1]],
|
||||
// shape (2, 3, 3, 4)
|
||||
&[
|
||||
[
|
||||
[[1_u8, 0, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]],
|
||||
[[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 0]],
|
||||
[[1, 0, 1, 1], [0, 0, 0, 1], [0, 1, 1, 0]],
|
||||
],
|
||||
[
|
||||
[[0, 1, 1, 0], [0, 0, 1, 0], [1, 1, 1, 0]],
|
||||
[[1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 0]],
|
||||
[[1, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 0]],
|
||||
],
|
||||
],
|
||||
)?;
|
||||
|
||||
// 4d vs 3d
|
||||
test(
|
||||
// Shape (2, 3, 3, 4)
|
||||
&[
|
||||
[
|
||||
[[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]],
|
||||
[[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]],
|
||||
[[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]],
|
||||
],
|
||||
[
|
||||
[[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]],
|
||||
[[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]],
|
||||
[[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]],
|
||||
],
|
||||
],
|
||||
// shape (3, 3, 4)
|
||||
&[
|
||||
[[1_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 0, 0]],
|
||||
[[0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]],
|
||||
[[0, 1, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]],
|
||||
],
|
||||
// shape (2, 3, 3, 4)
|
||||
&[
|
||||
[
|
||||
[[0_u8, 1, 0, 1], [1, 1, 1, 1], [0, 0, 0, 0]],
|
||||
[[1, 0, 0, 1], [0, 1, 0, 0], [1, 1, 0, 0]],
|
||||
[[1, 1, 1, 0], [0, 1, 0, 1], [1, 1, 1, 0]],
|
||||
],
|
||||
[
|
||||
[[1, 0, 0, 1], [1, 1, 1, 0], [1, 1, 1, 1]],
|
||||
[[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]],
|
||||
[[1, 1, 1, 0], [0, 1, 1, 1], [1, 0, 0, 0]],
|
||||
],
|
||||
],
|
||||
)?;
|
||||
|
||||
// 4d vs 4d
|
||||
test(
|
||||
// Shape (1, 4, 1, 2)
|
||||
&[[[[1_u8, 0]], [[1, 0]], [[1, 0]], [[1, 1]]]],
|
||||
// shape (2, 1, 4, 2)
|
||||
&[
|
||||
[[[0_u8, 0], [1, 1], [1, 1], [1, 1]]],
|
||||
[[[0, 1], [1, 0], [0, 1], [0, 0]]],
|
||||
],
|
||||
// shape (2, 4, 4, 2)
|
||||
&[
|
||||
[
|
||||
[[1_u8, 0], [0, 1], [0, 1], [0, 1]],
|
||||
[[1, 0], [0, 1], [0, 1], [0, 1]],
|
||||
[[1, 0], [0, 1], [0, 1], [0, 1]],
|
||||
[[1, 1], [0, 0], [0, 0], [0, 0]],
|
||||
],
|
||||
[
|
||||
[[1, 1], [0, 0], [1, 1], [1, 0]],
|
||||
[[1, 1], [0, 0], [1, 1], [1, 0]],
|
||||
[[1, 1], [0, 0], [1, 1], [1, 0]],
|
||||
[[1, 0], [0, 1], [1, 0], [1, 1]],
|
||||
],
|
||||
],
|
||||
)?;
|
||||
|
||||
fn test(input: impl NdArray, other: impl NdArray, expected: impl NdArray) -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Xor".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let inputs: HashMap<String, Tensor> = HashMap::from([
|
||||
(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?),
|
||||
(INPUT_Y.to_string(), Tensor::new(other, &Device::Cpu)?),
|
||||
]);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||
|
||||
match expected.dims().len() {
|
||||
0 => {
|
||||
assert_eq!(z.to_vec0::<u8>()?, expected.to_vec0::<u8>()?)
|
||||
}
|
||||
1 => {
|
||||
assert_eq!(z.to_vec1::<u8>()?, expected.to_vec1::<u8>()?)
|
||||
}
|
||||
2 => {
|
||||
assert_eq!(z.to_vec2::<u8>()?, expected.to_vec2::<u8>()?)
|
||||
}
|
||||
3 => {
|
||||
assert_eq!(z.to_vec3::<u8>()?, expected.to_vec3::<u8>()?)
|
||||
}
|
||||
4 => {
|
||||
// Candle has no method equivallent to `to_vec4()`
|
||||
// So, as a hack, we flatten it to a single dim vec to test the results
|
||||
assert_eq!(
|
||||
z.flatten_all()?.to_vec1::<u8>()?,
|
||||
expected.flatten_all()?.to_vec1::<u8>()?
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user