Added tests for ArgMax

This commit is contained in:
b1rtek
2024-05-09 20:45:53 +02:00
parent 8f1119b3e0
commit 4de76b89a2

View File

@ -2612,3 +2612,199 @@ fn test_argmin() -> Result<()> {
Ok(())
}
// "ArgMin"
#[test]
fn test_argmax() -> Result<()> {
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
// default_axes_keepdims
test(
&[
[2u32, 1u32],
[3u32, 10u32]
],
None,
Some(1),
None,
&[
[1i64, 1i64],
],
)?;
// keepdims
test(
&[
[2u32, 1u32],
[3u32, 10u32]
],
Some(1),
Some(1),
None,
&[
[0i64],
[1i64]
],
)?;
// // negative_axis_keepdims
test(
&[
[2u32, 1u32],
[3u32, 10u32]
],
Some(-1),
Some(1),
None,
&[
[0i64],
[1i64]
],
)?;
// no_keepdims
test(
&[
[2u32, 1u32],
[3u32, 10u32]
],
None,
Some(0),
None,
&[1i64, 1i64],
)?;
// tests from https://pytorch.org/docs/stable/generated/torch.argmax.html
test(
&[
[1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]
],
Some(1),
Some(0),
None,
&[0i64, 2i64, 0i64, 1i64],
)?;
test(
&[
[1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]
],
Some(1),
None,
None,
&[[0i64], [2i64], [0i64], [1i64]],
)?;
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
let att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
i: axis.unwrap_or(0),
doc_string: "axis".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_keepdims = AttributeProto {
name: "keepdims".to_string(),
ref_attr_name: "keepdims".to_string(),
i: keepdims.unwrap_or(1),
doc_string: "keepdims".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_select_last_index = AttributeProto {
name: "select_last_index".to_string(),
ref_attr_name: "select_last_index".to_string(),
i: select_last_index.unwrap_or(0),
doc_string: "select_last_index".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let attrs = {
let mut mut_attrs = vec![];
if axis.is_some() {
mut_attrs.push(att_axis);
}
if keepdims.is_some() {
mut_attrs.push(att_keepdims);
}
if select_last_index.is_some() {
mut_attrs.push(att_select_last_index);
}
mut_attrs
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "ArgMax".to_string(),
domain: "".to_string(),
attribute: attrs,
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}