Added tests from pytorch examples

This commit is contained in:
b1rtek
2024-05-09 20:22:34 +02:00
parent 9a273196b7
commit c4743aa570

View File

@ -10,7 +10,6 @@ use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType; use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap; use std::collections::HashMap;
use candle_onnx::eval::Value;
const INPUT_X: &str = "x"; const INPUT_X: &str = "x";
const INPUT_Y: &str = "y"; const INPUT_Y: &str = "y";
@ -2432,7 +2431,7 @@ fn test_argmin() -> Result<()> {
Some(1), Some(1),
None, None,
&[ &[
[0u32, 0u32], [0i64, 0i64],
], ],
)?; )?;
// keepdims // keepdims
@ -2445,8 +2444,8 @@ fn test_argmin() -> Result<()> {
Some(1), Some(1),
None, None,
&[ &[
[1u32], [1i64],
[0u32] [0i64]
], ],
)?; )?;
// // negative_axis_keepdims // // negative_axis_keepdims
@ -2459,8 +2458,8 @@ fn test_argmin() -> Result<()> {
Some(1), Some(1),
None, None,
&[ &[
[1u32], [1i64],
[0u32] [0i64]
], ],
)?; )?;
// no_keepdims // no_keepdims
@ -2472,7 +2471,32 @@ fn test_argmin() -> Result<()> {
None, None,
Some(0), Some(0),
None, None,
&[0u32, 0u32], &[0i64, 0i64],
)?;
// tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin
test(
&[
[0.1139, 0.2254, -0.1381, 0.3687],
[1.0100, -1.1975, -0.0102, -0.4732],
[-0.9240, 0.1207, -0.7506, -1.0213],
[1.7809, -1.2960, 0.9384, 0.1438]
],
Some(1),
Some(0),
None,
&[2i64, 1i64, 3i64, 1i64],
)?;
test(
&[
[0.1139, 0.2254, -0.1381, 0.3687],
[1.0100, -1.1975, -0.0102, -0.4732],
[-0.9240, 0.1207, -0.7506, -1.0213],
[1.7809, -1.2960, 0.9384, 0.1438]
],
Some(1),
None,
None,
&[[2i64], [1i64], [3i64], [1i64]],
)?; )?;
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> { fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
let att_axis = AttributeProto { let att_axis = AttributeProto {
@ -2578,8 +2602,8 @@ fn test_argmin() -> Result<()> {
let expected = Tensor::new(expected, &Device::Cpu)?; let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() { match expected.dims().len() {
1 => assert_eq!(z.to_vec1::<u32>()?, expected.to_vec1::<u32>()?), 1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
2 => assert_eq!(z.to_vec2::<u32>()?, expected.to_vec2::<u32>()?), 2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
_ => unreachable!(), _ => unreachable!(),
}; };