mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Added tests from pytorch examples
This commit is contained in:
@ -10,7 +10,6 @@ use candle_onnx::onnx::attribute_proto::AttributeType;
|
|||||||
use candle_onnx::onnx::tensor_proto::DataType;
|
use candle_onnx::onnx::tensor_proto::DataType;
|
||||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use candle_onnx::eval::Value;
|
|
||||||
|
|
||||||
const INPUT_X: &str = "x";
|
const INPUT_X: &str = "x";
|
||||||
const INPUT_Y: &str = "y";
|
const INPUT_Y: &str = "y";
|
||||||
@ -2432,7 +2431,7 @@ fn test_argmin() -> Result<()> {
|
|||||||
Some(1),
|
Some(1),
|
||||||
None,
|
None,
|
||||||
&[
|
&[
|
||||||
[0u32, 0u32],
|
[0i64, 0i64],
|
||||||
],
|
],
|
||||||
)?;
|
)?;
|
||||||
// keepdims
|
// keepdims
|
||||||
@ -2445,8 +2444,8 @@ fn test_argmin() -> Result<()> {
|
|||||||
Some(1),
|
Some(1),
|
||||||
None,
|
None,
|
||||||
&[
|
&[
|
||||||
[1u32],
|
[1i64],
|
||||||
[0u32]
|
[0i64]
|
||||||
],
|
],
|
||||||
)?;
|
)?;
|
||||||
// // negative_axis_keepdims
|
// // negative_axis_keepdims
|
||||||
@ -2459,8 +2458,8 @@ fn test_argmin() -> Result<()> {
|
|||||||
Some(1),
|
Some(1),
|
||||||
None,
|
None,
|
||||||
&[
|
&[
|
||||||
[1u32],
|
[1i64],
|
||||||
[0u32]
|
[0i64]
|
||||||
],
|
],
|
||||||
)?;
|
)?;
|
||||||
// no_keepdims
|
// no_keepdims
|
||||||
@ -2472,7 +2471,32 @@ fn test_argmin() -> Result<()> {
|
|||||||
None,
|
None,
|
||||||
Some(0),
|
Some(0),
|
||||||
None,
|
None,
|
||||||
&[0u32, 0u32],
|
&[0i64, 0i64],
|
||||||
|
)?;
|
||||||
|
// tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[0.1139, 0.2254, -0.1381, 0.3687],
|
||||||
|
[1.0100, -1.1975, -0.0102, -0.4732],
|
||||||
|
[-0.9240, 0.1207, -0.7506, -1.0213],
|
||||||
|
[1.7809, -1.2960, 0.9384, 0.1438]
|
||||||
|
],
|
||||||
|
Some(1),
|
||||||
|
Some(0),
|
||||||
|
None,
|
||||||
|
&[2i64, 1i64, 3i64, 1i64],
|
||||||
|
)?;
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[0.1139, 0.2254, -0.1381, 0.3687],
|
||||||
|
[1.0100, -1.1975, -0.0102, -0.4732],
|
||||||
|
[-0.9240, 0.1207, -0.7506, -1.0213],
|
||||||
|
[1.7809, -1.2960, 0.9384, 0.1438]
|
||||||
|
],
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
&[[2i64], [1i64], [3i64], [1i64]],
|
||||||
)?;
|
)?;
|
||||||
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
|
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
|
||||||
let att_axis = AttributeProto {
|
let att_axis = AttributeProto {
|
||||||
@ -2578,8 +2602,8 @@ fn test_argmin() -> Result<()> {
|
|||||||
|
|
||||||
let expected = Tensor::new(expected, &Device::Cpu)?;
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
match expected.dims().len() {
|
match expected.dims().len() {
|
||||||
1 => assert_eq!(z.to_vec1::<u32>()?, expected.to_vec1::<u32>()?),
|
1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
|
||||||
2 => assert_eq!(z.to_vec2::<u32>()?, expected.to_vec2::<u32>()?),
|
2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user