mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Added tests from pytorch examples
This commit is contained in:
@ -10,7 +10,6 @@ use candle_onnx::onnx::attribute_proto::AttributeType;
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||
use std::collections::HashMap;
|
||||
use candle_onnx::eval::Value;
|
||||
|
||||
const INPUT_X: &str = "x";
|
||||
const INPUT_Y: &str = "y";
|
||||
@ -2432,7 +2431,7 @@ fn test_argmin() -> Result<()> {
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[0u32, 0u32],
|
||||
[0i64, 0i64],
|
||||
],
|
||||
)?;
|
||||
// keepdims
|
||||
@ -2445,8 +2444,8 @@ fn test_argmin() -> Result<()> {
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[1u32],
|
||||
[0u32]
|
||||
[1i64],
|
||||
[0i64]
|
||||
],
|
||||
)?;
|
||||
// // negative_axis_keepdims
|
||||
@ -2459,8 +2458,8 @@ fn test_argmin() -> Result<()> {
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[1u32],
|
||||
[0u32]
|
||||
[1i64],
|
||||
[0i64]
|
||||
],
|
||||
)?;
|
||||
// no_keepdims
|
||||
@ -2472,7 +2471,32 @@ fn test_argmin() -> Result<()> {
|
||||
None,
|
||||
Some(0),
|
||||
None,
|
||||
&[0u32, 0u32],
|
||||
&[0i64, 0i64],
|
||||
)?;
|
||||
// tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin
|
||||
test(
|
||||
&[
|
||||
[0.1139, 0.2254, -0.1381, 0.3687],
|
||||
[1.0100, -1.1975, -0.0102, -0.4732],
|
||||
[-0.9240, 0.1207, -0.7506, -1.0213],
|
||||
[1.7809, -1.2960, 0.9384, 0.1438]
|
||||
],
|
||||
Some(1),
|
||||
Some(0),
|
||||
None,
|
||||
&[2i64, 1i64, 3i64, 1i64],
|
||||
)?;
|
||||
test(
|
||||
&[
|
||||
[0.1139, 0.2254, -0.1381, 0.3687],
|
||||
[1.0100, -1.1975, -0.0102, -0.4732],
|
||||
[-0.9240, 0.1207, -0.7506, -1.0213],
|
||||
[1.7809, -1.2960, 0.9384, 0.1438]
|
||||
],
|
||||
Some(1),
|
||||
None,
|
||||
None,
|
||||
&[[2i64], [1i64], [3i64], [1i64]],
|
||||
)?;
|
||||
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
|
||||
let att_axis = AttributeProto {
|
||||
@ -2578,8 +2602,8 @@ fn test_argmin() -> Result<()> {
|
||||
|
||||
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||
match expected.dims().len() {
|
||||
1 => assert_eq!(z.to_vec1::<u32>()?, expected.to_vec1::<u32>()?),
|
||||
2 => assert_eq!(z.to_vec2::<u32>()?, expected.to_vec2::<u32>()?),
|
||||
1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
|
||||
2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user