mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
ONNX: GatherElements, Xor (#2568)
This commit is contained in:

committed by
GitHub

parent
dcd83336b6
commit
7c09215ef4
@ -670,6 +670,49 @@ fn simple_eval_(
|
|||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), xs);
|
values.insert(node.output[0].clone(), xs);
|
||||||
}
|
}
|
||||||
|
// https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
|
||||||
|
// A Note to fellow lurkers:
|
||||||
|
// The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
|
||||||
|
// and examples is incorrect.
|
||||||
|
// Use `torch.gather` for the validating/ verifying against the proper behaviour
|
||||||
|
"GatherElements" => {
|
||||||
|
let data = get(&node.input[0])?;
|
||||||
|
let indices = get(&node.input[1])?;
|
||||||
|
|
||||||
|
let rank = data.rank();
|
||||||
|
if rank != indices.rank() {
|
||||||
|
bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
|
||||||
|
}
|
||||||
|
|
||||||
|
let axis = {
|
||||||
|
let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||||
|
let axis = data.normalize_axis(axis_i64)?;
|
||||||
|
|
||||||
|
if axis >= rank {
|
||||||
|
bail!(
|
||||||
|
"axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
|
||||||
|
axis_i64,
|
||||||
|
rank - 1
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
axis
|
||||||
|
};
|
||||||
|
|
||||||
|
// index_select does not support negative indices, so normalize them
|
||||||
|
// to positive indices.
|
||||||
|
let indices = &{
|
||||||
|
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
|
||||||
|
let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
|
||||||
|
.to_dtype(indices.dtype())?;
|
||||||
|
let mask = indices.lt(&zeros)?;
|
||||||
|
mask.to_dtype(indices.dtype())?
|
||||||
|
.broadcast_mul(&max)?
|
||||||
|
.add(indices)?
|
||||||
|
};
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), data.gather(indices, axis)?);
|
||||||
|
}
|
||||||
"Shape" => {
|
"Shape" => {
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
|
||||||
let xs = get(&node.input[0])?;
|
let xs = get(&node.input[0])?;
|
||||||
@ -1891,6 +1934,16 @@ fn simple_eval_(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// https://onnx.ai/onnx/operators/onnx__Xor.html
|
||||||
|
"Xor" => {
|
||||||
|
// Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True)
|
||||||
|
let a = get(&node.input[0])?.gt(0_u8)?;
|
||||||
|
let b = get(&node.input[1])?.gt(0_u8)?;
|
||||||
|
|
||||||
|
let out = a.broadcast_add(&b)?.eq(1_u8)?;
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), out);
|
||||||
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1159,6 +1159,163 @@ fn test_gather_operation() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GatherElements
|
||||||
|
#[test]
|
||||||
|
fn test_gather_elements() -> Result<()> {
|
||||||
|
// all the tests below are verified against `torch.gather()`
|
||||||
|
|
||||||
|
// Rank 1 index
|
||||||
|
test(&[1.0, 2.0, 3.0, 4.0], &[3i64], 0, &[4.0])?;
|
||||||
|
|
||||||
|
// Rank 2 index
|
||||||
|
test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 1, &[[4.0]])?;
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_0
|
||||||
|
test(
|
||||||
|
&[[1., 2.], [3., 4.]],
|
||||||
|
&[[0i64, 0], [1, 0]],
|
||||||
|
1,
|
||||||
|
&[[1., 1.], [4., 3.]],
|
||||||
|
)?;
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_1
|
||||||
|
test(
|
||||||
|
&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
|
||||||
|
&[[1i64, 2, 0], [2, 0, 0]],
|
||||||
|
0,
|
||||||
|
&[[4., 8., 3.], [7., 2., 3.]],
|
||||||
|
)?;
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_negative_indices
|
||||||
|
test(
|
||||||
|
&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
|
||||||
|
&[[-1_i64, -2, 0], [-2, 0, 0]],
|
||||||
|
0,
|
||||||
|
&[[7., 5., 3.], [4., 2., 3.]],
|
||||||
|
)?;
|
||||||
|
test(
|
||||||
|
&[[1.0], [2.0], [3.0], [4.0]],
|
||||||
|
&[[3i64], [2]],
|
||||||
|
0,
|
||||||
|
&[[4.], [3.]],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Rank 3
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[[1.0, 2.0], [3.0, 4.0]],
|
||||||
|
[[5.0, 6.0], [7.0, 8.0]],
|
||||||
|
[[9.0, 10.0], [11.0, 12.0]],
|
||||||
|
[[13.0, 14.0], [15.0, 16.0]],
|
||||||
|
],
|
||||||
|
&[[[1i64]]],
|
||||||
|
0,
|
||||||
|
&[[[5.]]],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[[1.0, 2.0], [3.0, 4.0]],
|
||||||
|
[[5.0, 6.0], [7.0, 8.0]],
|
||||||
|
[[9.0, 10.0], [11.0, 12.0]],
|
||||||
|
[[13.0, 14.0], [15.0, 16.0]],
|
||||||
|
],
|
||||||
|
&[[[1i64]]],
|
||||||
|
1,
|
||||||
|
&[[[3.]]],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[[1.0, 2.0], [3.0, 4.0]],
|
||||||
|
[[5.0, 6.0], [7.0, 8.0]],
|
||||||
|
[[9.0, 10.0], [11.0, 12.0]],
|
||||||
|
[[13.0, 14.0], [15.0, 16.0]],
|
||||||
|
],
|
||||||
|
&[[[1i64], [0]]],
|
||||||
|
2,
|
||||||
|
&[[[2.], [3.]]],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Error cases
|
||||||
|
// Invalid index
|
||||||
|
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 0, &[[1., 2., 3., 4.]]).is_err());
|
||||||
|
// Invalid axis/ dim
|
||||||
|
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 2, &[[1., 2., 3., 4.]]).is_err());
|
||||||
|
// Invalid rank
|
||||||
|
assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[3i64], 0, &[[1.]]).is_err());
|
||||||
|
|
||||||
|
fn test(
|
||||||
|
data: impl NdArray,
|
||||||
|
indices: impl NdArray,
|
||||||
|
axis: i64,
|
||||||
|
expected: impl NdArray,
|
||||||
|
) -> Result<()> {
|
||||||
|
let att_axis = AttributeProto {
|
||||||
|
name: "axis".to_string(),
|
||||||
|
ref_attr_name: "axis".to_string(),
|
||||||
|
i: axis,
|
||||||
|
doc_string: "axis".to_string(),
|
||||||
|
r#type: 2,
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "GatherElements".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![att_axis],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// "Size"
|
// "Size"
|
||||||
#[test]
|
#[test]
|
||||||
fn test_size_operation() -> Result<()> {
|
fn test_size_operation() -> Result<()> {
|
||||||
@ -5340,3 +5497,375 @@ fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Xor
|
||||||
|
#[test]
|
||||||
|
fn test_xor() -> Result<()> {
|
||||||
|
// tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor xor
|
||||||
|
|
||||||
|
// 2d
|
||||||
|
test(
|
||||||
|
&[[0_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 1, 1]],
|
||||||
|
&[[1_u8, 1, 0, 0], [1, 0, 0, 1], [1, 1, 1, 0]],
|
||||||
|
&[[1_u8, 0, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1]],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// 3d
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[0_u8, 1, 1, 1, 1],
|
||||||
|
[0, 1, 1, 0, 0],
|
||||||
|
[1, 1, 1, 1, 1],
|
||||||
|
[0, 0, 0, 0, 1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[0, 0, 1, 1, 1],
|
||||||
|
[1, 0, 1, 1, 1],
|
||||||
|
[1, 1, 0, 0, 1],
|
||||||
|
[1, 0, 0, 1, 0],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1, 0, 0, 1, 1],
|
||||||
|
[1, 1, 1, 0, 0],
|
||||||
|
[1, 1, 0, 0, 1],
|
||||||
|
[1, 0, 0, 0, 1],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[1_u8, 0, 0, 1, 1],
|
||||||
|
[0, 0, 1, 0, 1],
|
||||||
|
[1, 0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 0, 0],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1, 0, 0, 1, 1],
|
||||||
|
[1, 0, 1, 1, 1],
|
||||||
|
[0, 1, 0, 1, 1],
|
||||||
|
[1, 1, 1, 0, 0],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[1, 1, 0, 1, 0],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[1, 1, 0, 1, 0],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[1_u8, 1, 1, 0, 0],
|
||||||
|
[0, 1, 0, 0, 1],
|
||||||
|
[0, 1, 1, 0, 1],
|
||||||
|
[0, 0, 0, 0, 1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1, 0, 1, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0],
|
||||||
|
[1, 0, 0, 1, 0],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1, 1, 1, 0, 1],
|
||||||
|
[0, 0, 1, 1, 0],
|
||||||
|
[1, 0, 1, 1, 1],
|
||||||
|
[0, 1, 0, 1, 1],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// 4d
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[[0_u8, 1, 1, 0], [1, 0, 0, 0], [1, 1, 0, 1]],
|
||||||
|
[[1, 1, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 0]],
|
||||||
|
[[1, 0, 0, 1], [1, 0, 1, 1], [1, 1, 0, 1]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[[1_u8, 0, 1, 0], [0, 0, 1, 1], [1, 0, 1, 0]],
|
||||||
|
[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[1, 1, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
|
||||||
|
[[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[[1_u8, 1, 0, 0], [1, 0, 1, 1], [0, 1, 1, 1]],
|
||||||
|
[[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 0]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[0, 0, 1, 0], [1, 0, 1, 1], [1, 0, 1, 0]],
|
||||||
|
[[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 1, 0]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor xor_broadcast
|
||||||
|
// 3d vs 1d
|
||||||
|
test(
|
||||||
|
// Shape (3, 4, 5)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[0_u8, 0, 0, 0, 1],
|
||||||
|
[0, 1, 0, 1, 1],
|
||||||
|
[1, 0, 0, 1, 1],
|
||||||
|
[0, 0, 1, 0, 1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[0, 1, 0, 1, 1],
|
||||||
|
[1, 1, 0, 0, 1],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[0, 0, 0, 0, 1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1, 1, 0, 1, 1],
|
||||||
|
[0, 0, 0, 1, 1],
|
||||||
|
[0, 1, 1, 0, 1],
|
||||||
|
[1, 1, 0, 1, 1],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
// shape (5)
|
||||||
|
&[1_u8, 0, 0, 1, 1],
|
||||||
|
// shape (3, 4, 5)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[1_u8, 0, 0, 1, 0],
|
||||||
|
[1, 1, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0],
|
||||||
|
[1, 0, 1, 1, 0],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1, 1, 0, 0, 0],
|
||||||
|
[0, 1, 0, 1, 0],
|
||||||
|
[1, 1, 1, 0, 1],
|
||||||
|
[1, 0, 0, 1, 0],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[0, 1, 0, 0, 0],
|
||||||
|
[1, 0, 0, 0, 0],
|
||||||
|
[1, 1, 1, 1, 0],
|
||||||
|
[0, 1, 0, 0, 0],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// 3d vs 2d
|
||||||
|
test(
|
||||||
|
// Shape (3, 4, 5)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[0_u8, 0, 0, 0, 1],
|
||||||
|
[0, 1, 0, 1, 1],
|
||||||
|
[1, 0, 0, 1, 1],
|
||||||
|
[0, 0, 1, 0, 1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[0, 1, 0, 1, 1],
|
||||||
|
[1, 1, 0, 0, 1],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[0, 0, 0, 0, 1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1, 1, 0, 1, 1],
|
||||||
|
[0, 0, 0, 1, 1],
|
||||||
|
[0, 1, 1, 0, 1],
|
||||||
|
[1, 1, 0, 1, 1],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
// shape (4, 5)
|
||||||
|
&[
|
||||||
|
[0_u8, 1, 0, 1, 0],
|
||||||
|
[0, 0, 1, 0, 0],
|
||||||
|
[1, 1, 0, 1, 1],
|
||||||
|
[1, 1, 0, 1, 0],
|
||||||
|
],
|
||||||
|
// shape (3, 4, 5)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[0_u8, 1, 0, 1, 1],
|
||||||
|
[0, 1, 1, 1, 1],
|
||||||
|
[0, 1, 0, 0, 0],
|
||||||
|
[1, 1, 1, 1, 1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[0, 0, 0, 0, 1],
|
||||||
|
[1, 1, 1, 0, 1],
|
||||||
|
[1, 0, 1, 0, 1],
|
||||||
|
[1, 1, 0, 1, 1],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[1, 0, 0, 0, 1],
|
||||||
|
[0, 0, 1, 1, 1],
|
||||||
|
[1, 0, 1, 1, 0],
|
||||||
|
[0, 0, 0, 0, 1],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// 4d vs 2d
|
||||||
|
test(
|
||||||
|
// Shape (2, 3, 3, 4)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]],
|
||||||
|
[[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]],
|
||||||
|
[[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]],
|
||||||
|
[[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]],
|
||||||
|
[[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
// shape (3, 4)
|
||||||
|
&[[0_u8, 0, 1, 1], [1, 1, 1, 1], [0, 1, 0, 1]],
|
||||||
|
// shape (2, 3, 3, 4)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[[1_u8, 0, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]],
|
||||||
|
[[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 0]],
|
||||||
|
[[1, 0, 1, 1], [0, 0, 0, 1], [0, 1, 1, 0]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[0, 1, 1, 0], [0, 0, 1, 0], [1, 1, 1, 0]],
|
||||||
|
[[1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 0]],
|
||||||
|
[[1, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 0]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// 4d vs 3d
|
||||||
|
test(
|
||||||
|
// Shape (2, 3, 3, 4)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]],
|
||||||
|
[[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]],
|
||||||
|
[[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]],
|
||||||
|
[[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]],
|
||||||
|
[[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
// shape (3, 3, 4)
|
||||||
|
&[
|
||||||
|
[[1_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 0, 0]],
|
||||||
|
[[0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]],
|
||||||
|
[[0, 1, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]],
|
||||||
|
],
|
||||||
|
// shape (2, 3, 3, 4)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[[0_u8, 1, 0, 1], [1, 1, 1, 1], [0, 0, 0, 0]],
|
||||||
|
[[1, 0, 0, 1], [0, 1, 0, 0], [1, 1, 0, 0]],
|
||||||
|
[[1, 1, 1, 0], [0, 1, 0, 1], [1, 1, 1, 0]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[1, 0, 0, 1], [1, 1, 1, 0], [1, 1, 1, 1]],
|
||||||
|
[[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]],
|
||||||
|
[[1, 1, 1, 0], [0, 1, 1, 1], [1, 0, 0, 0]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// 4d vs 4d
|
||||||
|
test(
|
||||||
|
// Shape (1, 4, 1, 2)
|
||||||
|
&[[[[1_u8, 0]], [[1, 0]], [[1, 0]], [[1, 1]]]],
|
||||||
|
// shape (2, 1, 4, 2)
|
||||||
|
&[
|
||||||
|
[[[0_u8, 0], [1, 1], [1, 1], [1, 1]]],
|
||||||
|
[[[0, 1], [1, 0], [0, 1], [0, 0]]],
|
||||||
|
],
|
||||||
|
// shape (2, 4, 4, 2)
|
||||||
|
&[
|
||||||
|
[
|
||||||
|
[[1_u8, 0], [0, 1], [0, 1], [0, 1]],
|
||||||
|
[[1, 0], [0, 1], [0, 1], [0, 1]],
|
||||||
|
[[1, 0], [0, 1], [0, 1], [0, 1]],
|
||||||
|
[[1, 1], [0, 0], [0, 0], [0, 0]],
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[[1, 1], [0, 0], [1, 1], [1, 0]],
|
||||||
|
[[1, 1], [0, 0], [1, 1], [1, 0]],
|
||||||
|
[[1, 1], [0, 0], [1, 1], [1, 0]],
|
||||||
|
[[1, 0], [0, 1], [1, 0], [1, 1]],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
fn test(input: impl NdArray, other: impl NdArray, expected: impl NdArray) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Xor".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let inputs: HashMap<String, Tensor> = HashMap::from([
|
||||||
|
(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?),
|
||||||
|
(INPUT_Y.to_string(), Tensor::new(other, &Device::Cpu)?),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
|
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => {
|
||||||
|
assert_eq!(z.to_vec0::<u8>()?, expected.to_vec0::<u8>()?)
|
||||||
|
}
|
||||||
|
1 => {
|
||||||
|
assert_eq!(z.to_vec1::<u8>()?, expected.to_vec1::<u8>()?)
|
||||||
|
}
|
||||||
|
2 => {
|
||||||
|
assert_eq!(z.to_vec2::<u8>()?, expected.to_vec2::<u8>()?)
|
||||||
|
}
|
||||||
|
3 => {
|
||||||
|
assert_eq!(z.to_vec3::<u8>()?, expected.to_vec3::<u8>()?)
|
||||||
|
}
|
||||||
|
4 => {
|
||||||
|
// Candle has no method equivallent to `to_vec4()`
|
||||||
|
// So, as a hack, we flatten it to a single dim vec to test the results
|
||||||
|
assert_eq!(
|
||||||
|
z.flatten_all()?.to_vec1::<u8>()?,
|
||||||
|
expected.flatten_all()?.to_vec1::<u8>()?
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user