diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index ed9af9ae..09a6edb9 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -10,7 +10,6 @@ use candle_onnx::onnx::attribute_proto::AttributeType; use candle_onnx::onnx::tensor_proto::DataType; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use std::collections::HashMap; -use candle_onnx::eval::Value; const INPUT_X: &str = "x"; const INPUT_Y: &str = "y"; @@ -2432,7 +2431,7 @@ fn test_argmin() -> Result<()> { Some(1), None, &[ - [0u32, 0u32], + [0i64, 0i64], ], )?; // keepdims @@ -2445,8 +2444,8 @@ fn test_argmin() -> Result<()> { Some(1), None, &[ - [1u32], - [0u32] + [1i64], + [0i64] ], )?; // // negative_axis_keepdims @@ -2459,8 +2458,8 @@ fn test_argmin() -> Result<()> { Some(1), None, &[ - [1u32], - [0u32] + [1i64], + [0i64] ], )?; // no_keepdims @@ -2472,7 +2471,32 @@ fn test_argmin() -> Result<()> { None, Some(0), None, - &[0u32, 0u32], + &[0i64, 0i64], + )?; + // tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin + test( + &[ + [0.1139, 0.2254, -0.1381, 0.3687], + [1.0100, -1.1975, -0.0102, -0.4732], + [-0.9240, 0.1207, -0.7506, -1.0213], + [1.7809, -1.2960, 0.9384, 0.1438] + ], + Some(1), + Some(0), + None, + &[2i64, 1i64, 3i64, 1i64], + )?; + test( + &[ + [0.1139, 0.2254, -0.1381, 0.3687], + [1.0100, -1.1975, -0.0102, -0.4732], + [-0.9240, 0.1207, -0.7506, -1.0213], + [1.7809, -1.2960, 0.9384, 0.1438] + ], + Some(1), + None, + None, + &[[2i64], [1i64], [3i64], [1i64]], )?; fn test(data: impl NdArray, axis: Option, keepdims: Option, select_last_index: Option, expected: impl NdArray) -> Result<()> { let att_axis = AttributeProto { @@ -2578,8 +2602,8 @@ fn test_argmin() -> Result<()> { let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { - 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), - 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), + 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), + 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), _ => unreachable!(), };